diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/dynamic_tf_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/dynamic_tf_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..028d72d95713ea81d10f300697b8da1c39559435 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/dynamic_tf_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/dynamic_tf_policy_v2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/dynamic_tf_policy_v2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3580f47b179f7cdf6463069f147979ea032ce23 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/dynamic_tf_policy_v2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/eager_tf_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/eager_tf_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03e6150c331ac81c2169fe37f49d11bcc68bdeb2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/eager_tf_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/eager_tf_policy_v2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/eager_tf_policy_v2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c2bd751a9bf8a3fd74492dd188f74b1229aedb5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/eager_tf_policy_v2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34ab8212931c2a888d43c192b7fe5521f073b2db Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy_map.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy_map.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e646d2ea74b970647df18630f6072b7ef0b4811e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy_map.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy_template.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy_template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05632474e1a18544b14127a911c2473df1bb596d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy_template.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/rnn_sequencing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/rnn_sequencing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09516348f79d0a199466ff336f4ef8dbab8cfdbb Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/rnn_sequencing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/sample_batch.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/sample_batch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2932356893061954a0fcc94629f07d5ea778fe8f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/sample_batch.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_mixins.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_mixins.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66330618a0a9dd83c597e0a30a4cfa307eb3c046 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_mixins.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_policy_template.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_policy_template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dde2f0401b5b89f06b42fddf016d05e65bd5614 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_policy_template.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e05a1c0d450264eefac9ad73e11133a25849e0c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/view_requirement.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/view_requirement.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9f6502b3abc07f5211ea8b5b6fddd4e729eefdf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/view_requirement.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__init__.py b/.venv/lib/python3.11/site-packages/ray/tune/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d018c7f1e9d2474620336c4090b768f1c554ff1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/__init__.py @@ -0,0 +1,115 @@ +# isort: off +# Try import ray[tune] core requirements (defined in setup.py) +try: + import fsspec # noqa: F401 + import pandas # noqa: F401 + import pyarrow # noqa: F401 + import requests # noqa: F401 +except ImportError as exc: + raise ImportError( + "Can't import ray.tune as some dependencies are missing. " + 'Run `pip install "ray[tune]"` to fix.' 
+ ) from exc +# isort: on + +from ray.air.result import Result +from ray.tune.analysis import ExperimentAnalysis +from ray.tune.callback import Callback +from ray.tune.context import TuneContext, get_context +from ray.tune.error import TuneError +from ray.tune.execution.placement_groups import PlacementGroupFactory +from ray.tune.experiment import Experiment +from ray.tune.impl.config import CheckpointConfig, FailureConfig, RunConfig +from ray.tune.progress_reporter import ( + CLIReporter, + JupyterNotebookReporter, + ProgressReporter, +) +from ray.tune.registry import register_env, register_trainable +from ray.tune.result_grid import ResultGrid +from ray.tune.schedulers import create_scheduler +from ray.tune.search import create_searcher, grid_search +from ray.tune.search.sample import ( + choice, + lograndint, + loguniform, + qlograndint, + qloguniform, + qrandint, + qrandn, + quniform, + randint, + randn, + sample_from, + uniform, +) +from ray.tune.stopper import Stopper +from ray.tune.syncer import SyncConfig +from ray.tune.trainable import Trainable +from ray.tune.trainable.trainable_fn_utils import Checkpoint, get_checkpoint, report +from ray.tune.trainable.util import with_parameters, with_resources +from ray.tune.tune import run, run_experiments +from ray.tune.tune_config import ResumeConfig, TuneConfig +from ray.tune.tuner import Tuner + +__all__ = [ + "Trainable", + "Callback", + "TuneError", + "grid_search", + "register_env", + "register_trainable", + "run", + "run_experiments", + "with_parameters", + "with_resources", + "Stopper", + "Experiment", + "sample_from", + "uniform", + "quniform", + "choice", + "randint", + "lograndint", + "qrandint", + "qlograndint", + "randn", + "qrandn", + "loguniform", + "qloguniform", + "ExperimentAnalysis", + "CLIReporter", + "JupyterNotebookReporter", + "ProgressReporter", + "ResultGrid", + "create_searcher", + "create_scheduler", + "PlacementGroupFactory", + "Tuner", + "TuneConfig", + "ResumeConfig", + "RunConfig", + "CheckpointConfig", + "FailureConfig", + "Result", + "Checkpoint", + "get_checkpoint", + "report", + "get_context", + "TuneContext", + # TODO(justinvyu): [Deprecated] + "SyncConfig", +] + +report.__module__ = "ray.tune" +get_checkpoint.__module__ = "ray.tune" +get_context.__module__ = "ray.tune" +TuneContext.__module__ = "ray.tune" +Checkpoint.__module__ = "ray.tune" +Result.__module__ = "ray.tune" +RunConfig.__module__ = "ray.tune" +CheckpointConfig.__module__ = "ray.tune" +FailureConfig.__module__ = "ray.tune" + + +# DO NOT ADD ANYTHING AFTER THIS LINE. 
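Note: the `__init__.py` added above defines the public `ray.tune` surface used throughout the rest of this diff (`Tuner`, `TuneConfig`, `report`, the search-space helpers such as `loguniform`, and `ResultGrid`). As a minimal sketch of how that surface is typically exercised — assuming a recent Ray Tune 2.x install; the `objective` function, the `lr` parameter, and the `score` metric are illustrative names, not part of this diff:

# Hypothetical, minimal use of the API re-exported by ray/tune/__init__.py above.
# `objective`, `lr`, and `score` are illustrative names only.
from ray import tune


def objective(config):
    # Report one metric per iteration back to Tune.
    for step in range(10):
        tune.report({"score": config["lr"] * step})


tuner = tune.Tuner(
    objective,
    param_space={"lr": tune.loguniform(1e-4, 1e-1)},
    tune_config=tune.TuneConfig(metric="score", mode="max", num_samples=4),
)
results = tuner.fit()
print(results.get_best_result().config)

The dict form of `tune.report` is used here to match the `report` function re-exported above from `ray.tune.trainable.trainable_fn_utils`; `get_best_result()` falls back to the metric and mode declared in `TuneConfig`.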
diff --git a/.venv/lib/python3.11/site-packages/ray/tune/automl/__init__.py b/.venv/lib/python3.11/site-packages/ray/tune/automl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..24a70c9ef10e6e09db27739525f15df94ce99dd5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/automl/__init__.py @@ -0,0 +1 @@ +raise DeprecationWarning("`ray.tune.automl` is deprecated in Ray 2.6.") diff --git a/.venv/lib/python3.11/site-packages/ray/tune/automl/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/automl/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6525ffb0ae6a9ea99b7a308c2dbaf9e83d486953 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/automl/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/callback.py b/.venv/lib/python3.11/site-packages/ray/tune/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..3134ddba8082b4c0b45df9d0a15adc246a3b2d24 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/callback.py @@ -0,0 +1,512 @@ +import glob +import warnings +from abc import ABCMeta +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import ray.tune +from ray.tune.utils.util import _atomic_save, _load_newest_checkpoint +from ray.util.annotations import DeveloperAPI, PublicAPI + +if TYPE_CHECKING: + from ray.tune.experiment import Trial + from ray.tune.stopper import Stopper + + +class _CallbackMeta(ABCMeta): + """A helper metaclass to ensure container classes (e.g. CallbackList) have + implemented all the callback methods (e.g. `on_*`). + """ + + def __new__(mcs, name: str, bases: Tuple[type], attrs: Dict[str, Any]) -> type: + cls = super().__new__(mcs, name, bases, attrs) + + if mcs.need_check(cls, name, bases, attrs): + mcs.check(cls, name, bases, attrs) + + return cls + + @classmethod + def need_check( + mcs, cls: type, name: str, bases: Tuple[type], attrs: Dict[str, Any] + ) -> bool: + return attrs.get("IS_CALLBACK_CONTAINER", False) + + @classmethod + def check( + mcs, cls: type, name: str, bases: Tuple[type], attrs: Dict[str, Any] + ) -> None: + methods = set() + for base in bases: + methods.update( + attr_name + for attr_name, attr in vars(base).items() + if mcs.need_override_by_subclass(attr_name, attr) + ) + overridden = { + attr_name + for attr_name, attr in attrs.items() + if mcs.need_override_by_subclass(attr_name, attr) + } + missing = methods.difference(overridden) + if missing: + raise TypeError( + f"Found missing callback method: {missing} " + f"in class {cls.__module__}.{cls.__qualname__}." + ) + + @classmethod + def need_override_by_subclass(mcs, attr_name: str, attr: Any) -> bool: + return ( + ( + attr_name.startswith("on_") + and not attr_name.startswith("on_trainer_init") + ) + or attr_name == "setup" + ) and callable(attr) + + +@PublicAPI(stability="beta") +class Callback(metaclass=_CallbackMeta): + """Tune base callback that can be extended and passed to a ``TrialRunner`` + + Tune callbacks are called from within the ``TrialRunner`` class. There are + several hooks that can be used, all of which are found in the submethod + definitions of this base class. + + The parameters passed to the ``**info`` dict vary between hooks. The + parameters passed are described in the docstrings of the methods. + + This example will print a metric each time a result is received: + + .. 
testcode:: + + from ray import train, tune + from ray.tune import Callback + + + class MyCallback(Callback): + def on_trial_result(self, iteration, trials, trial, result, + **info): + print(f"Got result: {result['metric']}") + + + def train_func(config): + for i in range(10): + tune.report(metric=i) + + tuner = tune.Tuner( + train_func, + run_config=train.RunConfig( + callbacks=[MyCallback()] + ) + ) + tuner.fit() + + .. testoutput:: + :hide: + + ... + """ + + # File templates for any artifacts written by this callback + # These files should live in the `trial.local_path` for each trial. + # TODO(ml-team): Make this more visible to users to override. Internal use for now. + _SAVED_FILE_TEMPLATES = [] + + # arguments here match Experiment.public_spec + def setup( + self, + stop: Optional["Stopper"] = None, + num_samples: Optional[int] = None, + total_num_samples: Optional[int] = None, + **info, + ): + """Called once at the very beginning of training. + + Any Callback setup should be added here (setting environment + variables, etc.) + + Arguments: + stop: Stopping criteria. + If ``time_budget_s`` was passed to ``train.RunConfig``, a + ``TimeoutStopper`` will be passed here, either by itself + or as a part of a ``CombinedStopper``. + num_samples: Number of times to sample from the + hyperparameter space. Defaults to 1. If `grid_search` is + provided as an argument, the grid will be repeated + `num_samples` of times. If this is -1, (virtually) infinite + samples are generated until a stopping condition is met. + total_num_samples: Total number of samples factoring + in grid search samplers. + **info: Kwargs dict for forward compatibility. + """ + pass + + def on_step_begin(self, iteration: int, trials: List["Trial"], **info): + """Called at the start of each tuning loop step. + + Arguments: + iteration: Number of iterations of the tuning loop. + trials: List of trials. + **info: Kwargs dict for forward compatibility. + """ + pass + + def on_step_end(self, iteration: int, trials: List["Trial"], **info): + """Called at the end of each tuning loop step. + + The iteration counter is increased before this hook is called. + + Arguments: + iteration: Number of iterations of the tuning loop. + trials: List of trials. + **info: Kwargs dict for forward compatibility. + """ + pass + + def on_trial_start( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + """Called after starting a trial instance. + + Arguments: + iteration: Number of iterations of the tuning loop. + trials: List of trials. + trial: Trial that just has been started. + **info: Kwargs dict for forward compatibility. + + """ + pass + + def on_trial_restore( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + """Called after restoring a trial instance. + + Arguments: + iteration: Number of iterations of the tuning loop. + trials: List of trials. + trial: Trial that just has been restored. + **info: Kwargs dict for forward compatibility. + """ + pass + + def on_trial_save( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + """Called after receiving a checkpoint from a trial. + + Arguments: + iteration: Number of iterations of the tuning loop. + trials: List of trials. + trial: Trial that just saved a checkpoint. + **info: Kwargs dict for forward compatibility. + """ + pass + + def on_trial_result( + self, + iteration: int, + trials: List["Trial"], + trial: "Trial", + result: Dict, + **info, + ): + """Called after receiving a result from a trial. 
+ + The search algorithm and scheduler are notified before this + hook is called. + + Arguments: + iteration: Number of iterations of the tuning loop. + trials: List of trials. + trial: Trial that just sent a result. + result: Result that the trial sent. + **info: Kwargs dict for forward compatibility. + """ + pass + + def on_trial_complete( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + """Called after a trial instance completed. + + The search algorithm and scheduler are notified before this + hook is called. + + Arguments: + iteration: Number of iterations of the tuning loop. + trials: List of trials. + trial: Trial that just has been completed. + **info: Kwargs dict for forward compatibility. + """ + pass + + def on_trial_recover( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + """Called after a trial instance failed (errored) but the trial is scheduled + for retry. + + The search algorithm and scheduler are not notified. + + Arguments: + iteration: Number of iterations of the tuning loop. + trials: List of trials. + trial: Trial that just has errored. + **info: Kwargs dict for forward compatibility. + """ + pass + + def on_trial_error( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + """Called after a trial instance failed (errored). + + The search algorithm and scheduler are notified before this + hook is called. + + Arguments: + iteration: Number of iterations of the tuning loop. + trials: List of trials. + trial: Trial that just has errored. + **info: Kwargs dict for forward compatibility. + """ + pass + + def on_checkpoint( + self, + iteration: int, + trials: List["Trial"], + trial: "Trial", + checkpoint: "ray.tune.Checkpoint", + **info, + ): + """Called after a trial saved a checkpoint with Tune. + + Arguments: + iteration: Number of iterations of the tuning loop. + trials: List of trials. + trial: Trial that just has errored. + checkpoint: Checkpoint object that has been saved + by the trial. + **info: Kwargs dict for forward compatibility. + """ + pass + + def on_experiment_end(self, trials: List["Trial"], **info): + """Called after experiment is over and all trials have concluded. + + Arguments: + trials: List of trials. + **info: Kwargs dict for forward compatibility. + """ + pass + + def get_state(self) -> Optional[Dict]: + """Get the state of the callback. + + This method should be implemented by subclasses to return a dictionary + representation of the object's current state. + + This is called automatically by Tune to periodically checkpoint callback state. + Upon :ref:`Tune experiment restoration `, + callback state will be restored via :meth:`~ray.tune.Callback.set_state`. + + .. testcode:: + + from typing import Dict, List, Optional + + from ray.tune import Callback + from ray.tune.experiment import Trial + + class MyCallback(Callback): + def __init__(self): + self._trial_ids = set() + + def on_trial_start( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + self._trial_ids.add(trial.trial_id) + + def get_state(self) -> Optional[Dict]: + return {"trial_ids": self._trial_ids.copy()} + + def set_state(self, state: Dict) -> Optional[Dict]: + self._trial_ids = state["trial_ids"] + + Returns: + dict: State of the callback. Should be `None` if the callback does not + have any state to save (this is the default). + """ + return None + + def set_state(self, state: Dict): + """Set the state of the callback. 
+ + This method should be implemented by subclasses to restore the callback's + state based on the given dict state. + + This is used automatically by Tune to restore checkpoint callback state + on :ref:`Tune experiment restoration `. + + See :meth:`~ray.tune.Callback.get_state` for an example implementation. + + Args: + state: State of the callback. + """ + pass + + +@DeveloperAPI +class CallbackList(Callback): + """Call multiple callbacks at once.""" + + IS_CALLBACK_CONTAINER = True + CKPT_FILE_TMPL = "callback-states-{}.pkl" + + def __init__(self, callbacks: List[Callback]): + self._callbacks = callbacks + + def setup(self, **info): + for callback in self._callbacks: + try: + callback.setup(**info) + except TypeError as e: + if "argument" in str(e): + warnings.warn( + "Please update `setup` method in callback " + f"`{callback.__class__}` to match the method signature" + " in `ray.tune.callback.Callback`.", + FutureWarning, + ) + callback.setup() + else: + raise e + + def on_step_begin(self, **info): + for callback in self._callbacks: + callback.on_step_begin(**info) + + def on_step_end(self, **info): + for callback in self._callbacks: + callback.on_step_end(**info) + + def on_trial_start(self, **info): + for callback in self._callbacks: + callback.on_trial_start(**info) + + def on_trial_restore(self, **info): + for callback in self._callbacks: + callback.on_trial_restore(**info) + + def on_trial_save(self, **info): + for callback in self._callbacks: + callback.on_trial_save(**info) + + def on_trial_result(self, **info): + for callback in self._callbacks: + callback.on_trial_result(**info) + + def on_trial_complete(self, **info): + for callback in self._callbacks: + callback.on_trial_complete(**info) + + def on_trial_recover(self, **info): + for callback in self._callbacks: + callback.on_trial_recover(**info) + + def on_trial_error(self, **info): + for callback in self._callbacks: + callback.on_trial_error(**info) + + def on_checkpoint(self, **info): + for callback in self._callbacks: + callback.on_checkpoint(**info) + + def on_experiment_end(self, **info): + for callback in self._callbacks: + callback.on_experiment_end(**info) + + def get_state(self) -> Optional[Dict]: + """Gets the state of all callbacks contained within this list. + If there are no stateful callbacks, then None will be returned in order + to avoid saving an unnecessary callback checkpoint file.""" + state = {} + any_stateful_callbacks = False + for i, callback in enumerate(self._callbacks): + callback_state = callback.get_state() + if callback_state: + any_stateful_callbacks = True + state[i] = callback_state + if not any_stateful_callbacks: + return None + return state + + def set_state(self, state: Dict): + """Sets the state for all callbacks contained within this list. + Skips setting state for all stateless callbacks where `get_state` + returned None.""" + for i, callback in enumerate(self._callbacks): + callback_state = state.get(i, None) + if callback_state: + callback.set_state(callback_state) + + def save_to_dir(self, checkpoint_dir: str, session_str: str = "default"): + """Save the state of the callback list to the checkpoint_dir. + + Args: + checkpoint_dir: directory where the checkpoint is stored. + session_str: Unique identifier of the current run session (ex: timestamp). 
+ """ + state_dict = self.get_state() + + if state_dict: + file_name = self.CKPT_FILE_TMPL.format(session_str) + tmp_file_name = f".tmp-{file_name}" + _atomic_save( + state=state_dict, + checkpoint_dir=checkpoint_dir, + file_name=file_name, + tmp_file_name=tmp_file_name, + ) + + def restore_from_dir(self, checkpoint_dir: str): + """Restore the state of the list of callbacks from the checkpoint_dir. + + You should check if it's possible to restore with `can_restore` + before calling this method. + + Args: + checkpoint_dir: directory where the checkpoint is stored. + + Raises: + RuntimeError: if unable to find checkpoint. + NotImplementedError: if the `set_state` method is not implemented. + """ + state_dict = _load_newest_checkpoint( + checkpoint_dir, self.CKPT_FILE_TMPL.format("*") + ) + if not state_dict: + raise RuntimeError( + "Unable to find checkpoint in {}.".format(checkpoint_dir) + ) + self.set_state(state_dict) + + def can_restore(self, checkpoint_dir: str) -> bool: + """Check if the checkpoint_dir contains the saved state for this callback list. + + Returns: + can_restore: True if the checkpoint_dir contains a file of the + format `CKPT_FILE_TMPL`. False otherwise. + """ + return any( + glob.iglob(Path(checkpoint_dir, self.CKPT_FILE_TMPL.format("*")).as_posix()) + ) + + def __len__(self) -> int: + return len(self._callbacks) + + def __getitem__(self, i: int) -> "Callback": + return self._callbacks[i] diff --git a/.venv/lib/python3.11/site-packages/ray/tune/constants.py b/.venv/lib/python3.11/site-packages/ray/tune/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..36caebcf3aafcd70cadb9e0a883f2b39aa40d955 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/constants.py @@ -0,0 +1,32 @@ +# ================================================== +# Environment Variables +# ================================================== + +# NOTE: When adding a new environment variable, please track it in this list. 
+TUNE_ENV_VARS = { + "RAY_AIR_LOCAL_CACHE_DIR", + "TUNE_DISABLE_AUTO_CALLBACK_LOGGERS", + "TUNE_DISABLE_AUTO_INIT", + "TUNE_DISABLE_DATED_SUBDIR", + "TUNE_DISABLE_STRICT_METRIC_CHECKING", + "TUNE_DISABLE_SIGINT_HANDLER", + "TUNE_FORCE_TRIAL_CLEANUP_S", + "TUNE_FUNCTION_THREAD_TIMEOUT_S", + "TUNE_GLOBAL_CHECKPOINT_S", + "TUNE_MAX_LEN_IDENTIFIER", + "TUNE_MAX_PENDING_TRIALS_PG", + "TUNE_PLACEMENT_GROUP_PREFIX", + "TUNE_PLACEMENT_GROUP_RECON_INTERVAL", + "TUNE_PRINT_ALL_TRIAL_ERRORS", + "TUNE_RESULT_DIR", + "TUNE_RESULT_BUFFER_LENGTH", + "TUNE_RESULT_DELIM", + "TUNE_RESULT_BUFFER_MAX_TIME_S", + "TUNE_RESULT_BUFFER_MIN_TIME_S", + "TUNE_WARN_THRESHOLD_S", + "TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S", + "TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S_AUTOSCALER", + "TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S", + "TUNE_STATE_REFRESH_PERIOD", + "TUNE_RESTORE_RETRY_NUM", +} diff --git a/.venv/lib/python3.11/site-packages/ray/tune/context.py b/.venv/lib/python3.11/site-packages/ray/tune/context.py new file mode 100644 index 0000000000000000000000000000000000000000..0575a2b7af125b82469b92c46a01a5624fd17b53 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/context.py @@ -0,0 +1,113 @@ +import threading +from typing import Any, Dict, Optional + +from ray.train._internal import session +from ray.train.constants import _v2_migration_warnings_enabled +from ray.train.context import TrainContext as TrainV1Context +from ray.train.utils import _copy_doc +from ray.tune.execution.placement_groups import PlacementGroupFactory +from ray.util.annotations import Deprecated, PublicAPI + +# The context singleton on this process. +_tune_context: Optional["TuneContext"] = None +_tune_context_lock = threading.Lock() + + +_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE = ( + "`{}` is deprecated for Ray Tune because there is no concept of worker ranks " + "for Ray Tune, so these methods only make sense to use in the context of " + "a Ray Train worker." +) + + +@PublicAPI(stability="beta") +class TuneContext(TrainV1Context): + """Context to access metadata within Ray Tune functions.""" + + # NOTE: These methods are deprecated on the TrainContext, but are still + # available on the TuneContext. Re-defining them here to avoid the + # deprecation warnings. + + @_copy_doc(session.get_trial_name) + def get_trial_name(self) -> str: + return session.get_trial_name() + + @_copy_doc(session.get_trial_id) + def get_trial_id(self) -> str: + return session.get_trial_id() + + @_copy_doc(session.get_trial_resources) + def get_trial_resources(self) -> PlacementGroupFactory: + return session.get_trial_resources() + + @_copy_doc(session.get_trial_dir) + def get_trial_dir(self) -> str: + return session.get_trial_dir() + + # Deprecated APIs + + @Deprecated + def get_metadata(self) -> Dict[str, Any]: + raise DeprecationWarning( + "`get_metadata` is deprecated for Ray Tune, as it has never been usable." 
+ ) + + @Deprecated( + message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_world_size"), + warning=_v2_migration_warnings_enabled(), + ) + @_copy_doc(TrainV1Context.get_world_size) + def get_world_size(self) -> int: + return session.get_world_size() + + @Deprecated( + message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_world_rank"), + warning=_v2_migration_warnings_enabled(), + ) + @_copy_doc(TrainV1Context.get_world_rank) + def get_world_rank(self) -> int: + return session.get_world_rank() + + @Deprecated( + message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_local_rank"), + warning=_v2_migration_warnings_enabled(), + ) + @_copy_doc(TrainV1Context.get_local_rank) + def get_local_rank(self) -> int: + return session.get_local_rank() + + @Deprecated( + message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format( + "get_local_world_size" + ), + warning=_v2_migration_warnings_enabled(), + ) + @_copy_doc(TrainV1Context.get_local_world_size) + def get_local_world_size(self) -> int: + return session.get_local_world_size() + + @Deprecated( + message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_node_rank"), + warning=_v2_migration_warnings_enabled(), + ) + @_copy_doc(TrainV1Context.get_node_rank) + def get_node_rank(self) -> int: + return session.get_node_rank() + + +@PublicAPI(stability="beta") +def get_context() -> TuneContext: + """Get or create a singleton Ray Tune context. + + The context is only available in a tune function passed to the `ray.tune.Tuner`. + + See the :class:`~ray.tune.TuneContext` API reference to see available methods. + """ + global _tune_context + + with _tune_context_lock: + if _tune_context is None: + # TODO(justinvyu): This default should be a dummy context + # that is only used for testing / running outside of Tune. + _tune_context = TuneContext() + return _tune_context diff --git a/.venv/lib/python3.11/site-packages/ray/tune/error.py b/.venv/lib/python3.11/site-packages/ray/tune/error.py new file mode 100644 index 0000000000000000000000000000000000000000..9f2b427a2788e09a4c5bf5bc2d208dce31ae2dfd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/error.py @@ -0,0 +1,48 @@ +from ray.util.annotations import PublicAPI + + +@PublicAPI +class TuneError(Exception): + """General error class raised by ray.tune.""" + + pass + + +class _AbortTrialExecution(TuneError): + """Error that indicates a trial should not be retried.""" + + pass + + +class _SubCategoryTuneError(TuneError): + """The more specific TuneError that happens for a certain Tune + subroutine. For example starting/stopping a trial. + """ + + def __init__(self, traceback_str: str): + self.traceback_str = traceback_str + + def __str__(self): + return self.traceback_str + + +class _TuneStopTrialError(_SubCategoryTuneError): + """Error that happens when stopping a tune trial.""" + + pass + + +class _TuneStartTrialError(_SubCategoryTuneError): + """Error that happens when starting a tune trial.""" + + pass + + +class _TuneNoNextExecutorEventError(_SubCategoryTuneError): + """Error that happens when waiting to get the next event to + handle from RayTrialExecutor. + + Note: RayTaskError will be raised by itself and will not be using + this category. 
This category is for everything else.""" + + pass diff --git a/.venv/lib/python3.11/site-packages/ray/tune/examples/cifar10_pytorch.py b/.venv/lib/python3.11/site-packages/ray/tune/examples/cifar10_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..972773f77cb3225473074cdb3582b43ca41614aa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/examples/cifar10_pytorch.py @@ -0,0 +1,285 @@ +# ruff: noqa +# fmt: off + +# __import_begin__ +import os +import tempfile +from typing import Dict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms +from filelock import FileLock +from torch.utils.data import random_split + +import ray +from ray import train, tune +from ray.train import Checkpoint +from ray.tune.schedulers import ASHAScheduler + +# __import_end__ + + +# __load_data_begin__ +DATA_DIR = tempfile.mkdtemp() + +def load_data(data_dir): + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + # We add FileLock here because multiple workers will want to + # download data, and this may cause overwrites since + # DataLoader is not threadsafe. + with FileLock(os.path.expanduser("~/.data.lock")): + trainset = torchvision.datasets.CIFAR10( + root=data_dir, train=True, download=True, transform=transform) + + testset = torchvision.datasets.CIFAR10( + root=data_dir, train=False, download=True, transform=transform) + + return trainset, testset +# __load_data_end__ + +def load_test_data(): + # Loads a fake dataset for testing so it doesn't rely on external download. + trainset = torchvision.datasets.FakeData( + 128, (3, 32, 32), num_classes=10, transform=transforms.ToTensor() + ) + testset = torchvision.datasets.FakeData( + 16, (3, 32, 32), num_classes=10, transform=transforms.ToTensor() + ) + return trainset, testset + + +# __net_begin__ +class Net(nn.Module): + def __init__(self, l1=120, l2=84): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, l1) + self.fc2 = nn.Linear(l1, l2) + self.fc3 = nn.Linear(l2, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x +# __net_end__ + + +# __train_begin__ +def train_cifar(config): + net = Net(config["l1"], config["l2"]) + + device = "cpu" + if torch.cuda.is_available(): + device = "cuda:0" + if torch.cuda.device_count() > 1: + net = nn.DataParallel(net) + net.to(device) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9) + + # Load existing checkpoint through `get_checkpoint()` API. 
+ if train.get_checkpoint(): + loaded_checkpoint = train.get_checkpoint() + with loaded_checkpoint.as_directory() as loaded_checkpoint_dir: + model_state, optimizer_state = torch.load( + os.path.join(loaded_checkpoint_dir, "checkpoint.pt") + ) + net.load_state_dict(model_state) + optimizer.load_state_dict(optimizer_state) + + if config["smoke_test"]: + trainset, testset = load_test_data() + else: + trainset, testset = load_data(DATA_DIR) + + test_abs = int(len(trainset) * 0.8) + train_subset, val_subset = random_split( + trainset, [test_abs, len(trainset) - test_abs]) + + trainloader = torch.utils.data.DataLoader( + train_subset, + batch_size=int(config["batch_size"]), + shuffle=True, + num_workers=0 if config["smoke_test"] else 8, + ) + valloader = torch.utils.data.DataLoader( + val_subset, + batch_size=int(config["batch_size"]), + shuffle=True, + num_workers=0 if config["smoke_test"] else 8, + ) + + for epoch in range(10): # loop over the dataset multiple times + running_loss = 0.0 + epoch_steps = 0 + for i, data in enumerate(trainloader): + # get the inputs; data is a list of [inputs, labels] + inputs, labels = data + inputs, labels = inputs.to(device), labels.to(device) + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + backward + optimize + outputs = net(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + # print statistics + running_loss += loss.item() + epoch_steps += 1 + if i % 2000 == 1999: # print every 2000 mini-batches + print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, + running_loss / epoch_steps)) + running_loss = 0.0 + + # Validation loss + val_loss = 0.0 + val_steps = 0 + total = 0 + correct = 0 + for i, data in enumerate(valloader, 0): + with torch.no_grad(): + inputs, labels = data + inputs, labels = inputs.to(device), labels.to(device) + + outputs = net(inputs) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + loss = criterion(outputs, labels) + val_loss += loss.cpu().numpy() + val_steps += 1 + + # Here we save a checkpoint. It is automatically registered with + # Ray Tune and will potentially be accessed through in ``get_checkpoint()`` + # in future iterations. + # Note to save a file like checkpoint, you still need to put it under a directory + # to construct a checkpoint. 
+ with tempfile.TemporaryDirectory() as temp_checkpoint_dir: + path = os.path.join(temp_checkpoint_dir, "checkpoint.pt") + torch.save( + (net.state_dict(), optimizer.state_dict()), path + ) + checkpoint = Checkpoint.from_directory(temp_checkpoint_dir) + train.report( + {"loss": (val_loss / val_steps), "accuracy": correct / total}, + checkpoint=checkpoint, + ) + print("Finished Training") +# __train_end__ + + +# __test_acc_begin__ +def test_best_model(config: Dict, checkpoint: "Checkpoint", smoke_test=False): + best_trained_model = Net(config["l1"], config["l2"]) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + best_trained_model.to(device) + + with checkpoint.as_directory() as checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pt") + model_state, optimizer_state = torch.load(checkpoint_path) + best_trained_model.load_state_dict(model_state) + + if smoke_test: + _, testset = load_test_data() + else: + _, testset = load_data(DATA_DIR) + + testloader = torch.utils.data.DataLoader( + testset, batch_size=4, shuffle=False, num_workers=2) + + correct = 0 + total = 0 + with torch.no_grad(): + for data in testloader: + images, labels = data + images, labels = images.to(device), labels.to(device) + outputs = best_trained_model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + + print("Best trial test set accuracy: {}".format(correct / total)) + +# __test_acc_end__ + +# __main_begin__ +def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2, smoke_test=False): + config = { + "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), + "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), + "lr": tune.loguniform(1e-4, 1e-1), + "batch_size": tune.choice([2, 4, 8, 16]), + "smoke_test": smoke_test, + } + scheduler = ASHAScheduler( + max_t=max_num_epochs, + grace_period=1, + reduction_factor=2) + + tuner = tune.Tuner( + tune.with_resources( + tune.with_parameters(train_cifar), + resources={"cpu": 2, "gpu": gpus_per_trial}, + ), + tune_config=tune.TuneConfig( + metric="loss", + mode="min", + num_samples=num_samples, + scheduler=scheduler + ), + param_space=config, + ) + results = tuner.fit() + best_result = results.get_best_result("loss", "min") + print("Best trial config: {}".format(best_result.config)) + print("Best trial final validation loss: {}".format( + best_result.metrics["loss"])) + print("Best trial final validation accuracy: {}".format( + best_result.metrics["accuracy"])) + + test_best_model(best_result.config, best_result.checkpoint, smoke_test=smoke_test) + + +# __main_end__ + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + parser.add_argument( + "--ray-address", + help="Address of Ray cluster for seamless distributed execution.", + required=False) + args, _ = parser.parse_known_args() + + if args.smoke_test: + ray.init(num_cpus=2) + main(num_samples=1, max_num_epochs=1, gpus_per_trial=0, smoke_test=True) + else: + ray.init(args.ray_address) + # Change this to activate training on GPUs + main(num_samples=10, max_num_epochs=10, gpus_per_trial=0) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/examples/lightgbm_example.py b/.venv/lib/python3.11/site-packages/ray/tune/examples/lightgbm_example.py new file mode 100644 index 0000000000000000000000000000000000000000..3db060e86ec695680791cdc63b599b60d3719035 --- 
/dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/examples/lightgbm_example.py @@ -0,0 +1,105 @@ +import lightgbm as lgb +import sklearn.datasets +import sklearn.metrics +from sklearn.model_selection import train_test_split + +from ray import tune +from ray.tune.integration.lightgbm import TuneReportCheckpointCallback +from ray.tune.schedulers import ASHAScheduler + + +def train_breast_cancer(config: dict): + # This is a simple training function to be passed into Tune + + # Load dataset + data, target = sklearn.datasets.load_breast_cancer(return_X_y=True) + + # Split into train and test set + train_x, test_x, train_y, test_y = train_test_split(data, target, test_size=0.25) + + # Build input Datasets for LightGBM + train_set = lgb.Dataset(train_x, label=train_y) + test_set = lgb.Dataset(test_x, label=test_y) + + # Train the classifier, using the Tune callback + lgb.train( + config, + train_set, + valid_sets=[test_set], + valid_names=["eval"], + verbose_eval=False, + callbacks=[ + TuneReportCheckpointCallback( + { + "binary_error": "eval-binary_error", + "binary_logloss": "eval-binary_logloss", + } + ) + ], + ) + + +def train_breast_cancer_cv(config: dict): + # This is a simple training function to be passed into Tune, using + # lightgbm's cross validation functionality + + # Load dataset + data, target = sklearn.datasets.load_breast_cancer(return_X_y=True) + + train_set = lgb.Dataset(data, label=target) + + # Run CV, using the Tune callback + lgb.cv( + config, + train_set, + verbose_eval=False, + stratified=True, + # Checkpointing is not supported for CV + # LightGBM aggregates metrics over folds automatically + # with the cv_agg key. Both mean and standard deviation + # are provided. + callbacks=[ + TuneReportCheckpointCallback( + { + "binary_error": "cv_agg-binary_error-mean", + "binary_logloss": "cv_agg-binary_logloss-mean", + "binary_error_stdv": "cv_agg-binary_error-stdv", + "binary_logloss_stdv": "cv_agg-binary_logloss-stdv", + }, + frequency=0, + ) + ], + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--use-cv", action="store_true", help="Use `lgb.cv` instead of `lgb.train`." 
+ ) + args, _ = parser.parse_known_args() + + config = { + "objective": "binary", + "metric": ["binary_error", "binary_logloss"], + "verbose": -1, + "boosting_type": tune.grid_search(["gbdt", "dart"]), + "num_leaves": tune.randint(10, 1000), + "learning_rate": tune.loguniform(1e-8, 1e-1), + } + + tuner = tune.Tuner( + train_breast_cancer if not args.use_cv else train_breast_cancer_cv, + tune_config=tune.TuneConfig( + metric="binary_error", + mode="min", + num_samples=2, + scheduler=ASHAScheduler(), + ), + param_space=config, + ) + results = tuner.fit() + + print("Best hyperparameters found were: ", results.get_best_result().config) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/execution/__init__.py b/.venv/lib/python3.11/site-packages/ray/tune/execution/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..309e7cc67e2f9816811dbd03212ab39a720997b4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/class_cache.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/class_cache.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c59115f1a56afee81731a4f8303df4bc36305193 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/class_cache.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/cluster_info.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/cluster_info.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b30bde22b6b4bc71628794c62285b2239fe18455 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/cluster_info.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/experiment_state.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/experiment_state.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a30efa8d301235f1170e9b412494212956e66af Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/experiment_state.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/insufficient_resources_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/insufficient_resources_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e02e452fcd00ac714485c83b957d53c3f177b8b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/insufficient_resources_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/placement_groups.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/placement_groups.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42d91252441b513feb1de11d9f2b1bf952f1aa39 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/placement_groups.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/execution/class_cache.py b/.venv/lib/python3.11/site-packages/ray/tune/execution/class_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..94c4b5148a4db18ce506515469cc5272595f2af3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/execution/class_cache.py @@ -0,0 +1,68 @@ +import os + +import ray +from ray.air.constants import COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV +from ray.train.constants import ( + ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR, + RAY_CHDIR_TO_TRIAL_DIR, +) +from ray.train.v2._internal.constants import ( + ENV_VARS_TO_PROPAGATE as TRAIN_ENV_VARS_TO_PROPAGATE, +) + +DEFAULT_ENV_VARS = { + # https://github.com/ray-project/ray/issues/28197 + "PL_DISABLE_FORK": "1" +} +ENV_VARS_TO_PROPAGATE = ( + { + COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV, + RAY_CHDIR_TO_TRIAL_DIR, + ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR, + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SECURITY_TOKEN", + "AWS_SESSION_TOKEN", + } + # Propagate the Ray Train environment variables from the driver process + # to the trainable process so that Tune + Train v2 can be used together. + | TRAIN_ENV_VARS_TO_PROPAGATE +) + + +class _ActorClassCache: + """Caches actor classes. + + ray.remote is a registration call. It sends the serialized object to the + key value store (redis), and will be fetched at an arbitrary worker + later. Registration does not use any Ray scheduling resources. + + Later, class.remote() actually creates the remote actor. The + actor will be instantiated on some arbitrary machine, + according to the underlying Ray scheduler. + + Without this cache, you would register the same serialized object + over and over again. Naturally, since redis doesn’t spill to disk, + this can easily nuke the redis instance (and basically blow up Ray). + This cache instead allows us to register once and only once. + + Note that we assume there can be multiple trainables in the + system at once. + """ + + def __init__(self): + self._cache = {} + + def get(self, trainable_cls): + """Gets the wrapped trainable_cls, otherwise calls ray.remote.""" + env_vars = DEFAULT_ENV_VARS.copy() + + for env_var_to_propagate in ENV_VARS_TO_PROPAGATE: + if env_var_to_propagate in os.environ: + env_vars[env_var_to_propagate] = os.environ[env_var_to_propagate] + + runtime_env = {"env_vars": env_vars} + if trainable_cls not in self._cache: + remote_cls = ray.remote(runtime_env=runtime_env)(trainable_cls) + self._cache[trainable_cls] = remote_cls + return self._cache[trainable_cls] diff --git a/.venv/lib/python3.11/site-packages/ray/tune/execution/cluster_info.py b/.venv/lib/python3.11/site-packages/ray/tune/execution/cluster_info.py new file mode 100644 index 0000000000000000000000000000000000000000..23bb9d03dbd04a7ee8ff99ee07a1f455ef95ac9f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/execution/cluster_info.py @@ -0,0 +1,12 @@ +from functools import lru_cache +from pathlib import Path + + +@lru_cache() +def _is_ray_cluster(): + """Checks if the bootstrap config file exists. + + This will always exist if using an autoscaling cluster/started + with the ray cluster launcher. 
+ """ + return Path("~/ray_bootstrap_config.yaml").expanduser().exists() diff --git a/.venv/lib/python3.11/site-packages/ray/tune/execution/experiment_state.py b/.venv/lib/python3.11/site-packages/ray/tune/execution/experiment_state.py new file mode 100644 index 0000000000000000000000000000000000000000..8afda2a3a2751b6bb6f44f4e59a6bc24c89b81d3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/execution/experiment_state.py @@ -0,0 +1,287 @@ +import fnmatch +import logging +import os +import time +from collections import Counter +from pathlib import Path +from typing import Callable, Dict, Optional, Union + +import pyarrow.fs + +from ray.train._internal.storage import ( + StorageContext, + _download_from_fs_path, + _list_at_fs_path, + get_fs_and_path, +) +from ray.tune.experiment.trial import Trial +from ray.tune.impl.out_of_band_serialize_dataset import out_of_band_serialize_dataset + +logger = logging.getLogger(__name__) + + +_SLOW_SYNC_WARNING = ( + "This could be due to a large number of trials, " + "large logfiles from lots of reported metrics, or throttling from the " + "remote storage if uploading too frequently.\n" + "You may want to consider switching the `RunConfig(storage_filesystem)`" + " to a more performant storage backend such as s3fs for a " + "S3 storage path.\n" + "You can suppress this error by setting the environment variable " + "TUNE_WARN_SLOW_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a higher " + "value than the current threshold ({threshold})." +) + + +def _find_newest_experiment_checkpoint( + experiment_path: str, fs: Optional[pyarrow.fs.FileSystem] = None +) -> Optional[str]: + """Returns file name of most recently created experiment checkpoint. + + Args: + experiment_path: Local or remote path to the experiment directory + containing at least one experiment checkpoint file. + + Returns: + str: The local or remote path to the latest experiment checkpoint file + based on timestamp. None if no experiment checkpoints were found. + """ + from ray.tune.execution.tune_controller import TuneController + + fs, experiment_fs_path = get_fs_and_path(experiment_path, storage_filesystem=fs) + filenames = _list_at_fs_path(fs=fs, fs_path=experiment_fs_path) + pattern = TuneController.CKPT_FILE_TMPL.format("*") + matching = fnmatch.filter(filenames, pattern) + if not matching: + return None + filename = max(matching) + return Path(experiment_fs_path, filename).as_posix() + + +class _ExperimentCheckpointManager: + """Helper class for managing experiment-level checkpoints. + + This class implements the ``checkpoint()`` method used to checkpoint + experiment state. When called, this will serialize and write to disk + the state of the trial runner, trial executor, and search algorithm, to + a specified checkpoint file. + + The checkpoint period is automatically adjusted to + ``max(10, time_per_checkpoint * 19)``. This means that at most 5% of the + time (1/20) will be used for writing checkpoints, while 95% of the time + (19/20) will be used to handle the rest of the training loop. 
+ """ + + def __init__( + self, + *, + storage: Optional[StorageContext], + checkpoint_period: Union[int, float, str], + sync_every_n_trial_checkpoints: Optional[int] = None, + ): + self._storage = storage + + self._last_save_time = float("-inf") + self._last_sync_time = None + + # Dynamic checkpointing period + self._auto_checkpoint_enabled = checkpoint_period == "auto" + if self._auto_checkpoint_enabled: + self._checkpoint_period = 10.0 # Initial value + else: + self._checkpoint_period = float(checkpoint_period) + + # TODO(justinvyu): This is a non-performant workaround to force sync + # every num_to_keep checkpoints in order to maintain consistency + # between the experiment state's view of the latest checkpoint, + # and the actual latest checkpoint that was uploaded. + self._sync_every_n_trial_checkpoints = sync_every_n_trial_checkpoints + self._trial_num_checkpoints_since_last_sync: Dict[Trial, int] = Counter() + self._should_force_sync_up: bool = False + + self._excessive_sync_threshold = float( + os.environ.get( + "TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S", "5" + ) + ) + self._slow_sync_threshold = float( + os.environ.get( + "TUNE_WARN_SLOW_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S", "30" + ) + ) + + @property + def auto_checkpoint_enabled(self): + return self._auto_checkpoint_enabled + + def _update_auto_checkpoint_time(self, time_taken: float): + if self._auto_checkpoint_enabled: + # Multiplying this time by 19 means we spend ~5% of the time + # writing global checkpoints and 95% of the time processing trials + self._checkpoint_period = max(10.0, time_taken * 19) + logger.debug( + f"Experiment state snapshotting took " + f"{time_taken:.2f} seconds. " + f"Adjusting snapshotting period to " + f"{self._checkpoint_period:.2f} seconds." + ) + + def sync_up_experiment_state( + self, + save_fn: Callable[[], None], + force: bool = False, + wait: bool = False, + ): + """Saves execution state to the experiment directory on the storage path. + This includes an experiment checkpoint file that contains trial statuses + and the searcher state. + + Overwrites the current session checkpoint, which starts when self + is instantiated. Throttle depends on self._checkpoint_period. + + Args: + save_fn: Function to call to actually save data to the driver + staging path. The files in the driver staging path will be + uploaded to the storage path. + force: Forces an experiment checkpoint and launches a sync to storage. + This happens regardless of checkpoint_period + wait: Waits for the sync up to complete before returning. + """ + driver_staging_path = self._storage.experiment_driver_staging_path + + force = force or self._should_force_sync_up + + now = time.monotonic() + if now - self._last_save_time < self._checkpoint_period and not force: + return + + # Checkpoint + checkpoint_time_start = time.monotonic() + + # NOTE: This context manager is for Datasets captured in a trial config. + # This is the case when *tuning over datasets*. + # If the datasets have already been full executed, then serializing + # block refs means that this checkpoint is not usable in a new Ray cluster. + # This context will serialize the dataset execution plan instead, if available. 
+ with out_of_band_serialize_dataset(): + save_fn() + + def wait_for_sync(): + try: + self._storage.syncer.wait() + except Exception: + logger.error( + "Saving experiment state to storage at " + f"'{self._storage.experiment_fs_path}' failed with exception: ", + exc_info=True, + ) + + if force: + start_time = time.monotonic() + wait_for_sync() + wait_time = time.monotonic() - start_time + if wait_time > self._slow_sync_threshold: + logger.warning( + "Saving the experiment state (which holds a global view " + "of trial statuses and is used to restore the experiment) " + f"took ~{wait_time:.2f} seconds, which may be a performance " + "bottleneck.\n" + f"{_SLOW_SYNC_WARNING.format(threshold=self._slow_sync_threshold)}" + ) + + time_since_last_sync = ( + time.monotonic() - self._last_sync_time + if self._last_sync_time is not None + else None + ) + launched_sync = self._storage.syncer.sync_up( + driver_staging_path, self._storage.experiment_fs_path + ) + if launched_sync: + if ( + time_since_last_sync is not None + and time_since_last_sync < self._excessive_sync_threshold + and self._should_force_sync_up + ): + logger.warning( + "Experiment state snapshotting has been triggered multiple " + f"times in the last {self._excessive_sync_threshold} seconds " + "and may become a bottleneck. " + "A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, " + "and a trial has checkpointed >= `num_to_keep` times " + "since the last snapshot.\n" + "You may want to consider increasing the " + "`CheckpointConfig(num_to_keep)` or decreasing the frequency of " + "saving checkpoints.\n" + "You can suppress this warning by setting the environment variable " + "TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S " + "to a smaller value than the current threshold " + f"({self._excessive_sync_threshold}). " + "Set it to 0 to completely suppress this warning." + ) + + self._last_sync_time = time.monotonic() + + # We just synced, so reset the force flag + self._trial_num_checkpoints_since_last_sync.clear() + self._should_force_sync_up = False + else: + if ( + time_since_last_sync is not None + and time_since_last_sync > self._slow_sync_threshold + ): + logger.warning( + "Saving the experiment state (which holds a global view " + "of trial statuses and is used to restore the experiment) " + f"has already taken {time_since_last_sync:.2f} seconds, " + "which may cause consistency issues upon restoration if your " + "driver script ungracefully exits.\n" + f"{_SLOW_SYNC_WARNING.format(threshold=self._slow_sync_threshold)}" + ) + + if wait: + wait_for_sync() + + checkpoint_time_taken = time.monotonic() - checkpoint_time_start + + # Adjust dynamic checkpointing + self._update_auto_checkpoint_time(time_taken=checkpoint_time_taken) + + # Finish + self._last_save_time = time.monotonic() + + def sync_down_experiment_state(self) -> None: + fs = self._storage.storage_filesystem + filepaths = _list_at_fs_path(fs=fs, fs_path=self._storage.experiment_fs_path) + # TODO(ekl) we should refactor our restore code to read the necessary data + # directly from the storage context. As a temporary hack, restore all the + # serialized files from the root dir where other modules expect them to be. 
+ matches = [ + path + for path in filepaths + if path.endswith(".json") or path.endswith(".pkl") + ] + for relpath in matches: + fs_path = Path(self._storage.experiment_fs_path, relpath).as_posix() + local_path = Path( + self._storage.experiment_driver_staging_path, relpath + ).as_posix() + _download_from_fs_path(fs=fs, fs_path=fs_path, local_path=local_path) + logger.debug( + f"Copied {matches} from:\n(fs, path) = " + f"({self._storage.storage_filesystem.type_name}, " + f"{self._storage.experiment_fs_path})\n" + f"-> {self._storage.experiment_driver_staging_path}" + ) + + def on_trial_checkpoint(self, trial: Trial): + if not self._sync_every_n_trial_checkpoints: + return + + self._trial_num_checkpoints_since_last_sync[trial] += 1 + + if ( + self._trial_num_checkpoints_since_last_sync[trial] + >= self._sync_every_n_trial_checkpoints + ): + self._should_force_sync_up = True diff --git a/.venv/lib/python3.11/site-packages/ray/tune/execution/insufficient_resources_manager.py b/.venv/lib/python3.11/site-packages/ray/tune/execution/insufficient_resources_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..0f788d2df7df64c3f34c12fef8b39d68ec758e66 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/execution/insufficient_resources_manager.py @@ -0,0 +1,167 @@ +import logging +import os +import time +from functools import lru_cache +from typing import Dict, Optional, Tuple + +import ray +from ray.tune.execution.cluster_info import _is_ray_cluster +from ray.tune.experiment import Trial + +logger = logging.getLogger(__name__) + + +# Ideally we want to use @cache; but it's only available for python 3.9. +# Caching is only helpful/correct for no autoscaler case. +@lru_cache() +def _get_cluster_resources_no_autoscaler() -> Dict: + return ray.cluster_resources() + + +def _get_trial_cpu_and_gpu(trial: Trial) -> Tuple[int, int]: + cpu = trial.placement_group_factory.required_resources.get("CPU", 0) + gpu = trial.placement_group_factory.required_resources.get("GPU", 0) + return cpu, gpu + + +def _can_fulfill_no_autoscaler(trial: Trial) -> bool: + """Calculates if there is enough resources for a PENDING trial. + + For no autoscaler case. + """ + assert trial.status == Trial.PENDING + asked_cpus, asked_gpus = _get_trial_cpu_and_gpu(trial) + + return asked_cpus <= _get_cluster_resources_no_autoscaler().get( + "CPU", 0 + ) and asked_gpus <= _get_cluster_resources_no_autoscaler().get("GPU", 0) + + +@lru_cache() +def _get_insufficient_resources_warning_threshold() -> float: + if _is_ray_cluster(): + return float( + os.environ.get( + "TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S_AUTOSCALER", "60" + ) + ) + else: + # Set the default to 10s so that we don't prematurely determine that + # a cluster cannot fulfill the resources requirements. + # TODO(xwjiang): Change it back once #18608 is resolved. + return float(os.environ.get("TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S", "60")) + + +MSG_TRAIN_START = ( + "Training has not started in the last {wait_time:.0f} seconds. " + "This could be due to the cluster not having enough resources available. " +) +MSG_TRAIN_INSUFFICIENT = ( + "You asked for {asked_cpus} CPUs and {asked_gpus} GPUs, but the cluster only " + "has {cluster_cpus} CPUs and {cluster_gpus} GPUs available. " +) +MSG_TRAIN_END = ( + "Stop the training and adjust the required resources (e.g. via the " + "`ScalingConfig` or `resources_per_trial`, or `num_workers` for rllib), " + "or add more resources to your cluster." 
+) + +MSG_TUNE_START = ( + "No trial is running and no new trial has been started within " + "the last {wait_time:.0f} seconds. " + "This could be due to the cluster not having enough resources available. " +) +MSG_TUNE_INSUFFICIENT = ( + "You asked for {asked_cpus} CPUs and {asked_gpus} GPUs per trial, " + "but the cluster only has {cluster_cpus} CPUs and {cluster_gpus} GPUs available. " +) +MSG_TUNE_END = ( + "Stop the tuning and adjust the required resources (e.g. via the " + "`ScalingConfig` or `resources_per_trial`, or `num_workers` for rllib), " + "or add more resources to your cluster." +) + + +# TODO(xwjiang): Consider having a help page with more detailed instructions. +@lru_cache() +def _get_insufficient_resources_warning_msg( + for_train: bool = False, trial: Optional[Trial] = None +) -> str: + msg = "Ignore this message if the cluster is autoscaling. " + + if for_train: + start = MSG_TRAIN_START + insufficient = MSG_TRAIN_INSUFFICIENT + end = MSG_TRAIN_END + else: + start = MSG_TUNE_START + insufficient = MSG_TUNE_INSUFFICIENT + end = MSG_TUNE_END + + msg += start.format(wait_time=_get_insufficient_resources_warning_threshold()) + + if trial: + asked_cpus, asked_gpus = _get_trial_cpu_and_gpu(trial) + cluster_resources = _get_cluster_resources_no_autoscaler() + + msg += insufficient.format( + asked_cpus=asked_cpus, + asked_gpus=asked_gpus, + cluster_cpus=cluster_resources.get("CPU", 0), + cluster_gpus=cluster_resources.get("GPU", 0), + ) + + msg += end + + return msg + + +class _InsufficientResourcesManager: + """Insufficient resources manager. + + Makes best effort, conservative guesses about if Tune loop is stuck due to + infeasible resources. If so, outputs usability messages for users to + act upon. + """ + + def __init__(self, for_train: bool = False): + # The information tracked across the life time of Tune loop. + self._no_running_trials_since = -1 + self._last_trial_num = -1 + self._for_train = for_train + + def on_no_available_trials(self, all_trials): + """Tracks information across the life of Tune loop and makes guesses + about if Tune loop is stuck due to infeasible resources. + If so, outputs certain warning messages. + The logic should be conservative, non-intrusive and informative. + For example, rate limiting is applied so that the message is not + spammy. + """ + # This is approximately saying we are not making progress. 
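+ # No new trials since the last call: start (or keep) the stall timer and, once the warning threshold is exceeded, check whether any PENDING trial could be fulfilled with the current cluster resources.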
+ if len(all_trials) == self._last_trial_num: + if self._no_running_trials_since == -1: + self._no_running_trials_since = time.monotonic() + elif ( + time.monotonic() - self._no_running_trials_since + > _get_insufficient_resources_warning_threshold() + ): + can_fulfill_any = any( + trial.status == Trial.PENDING and _can_fulfill_no_autoscaler(trial) + for trial in all_trials + ) + + if can_fulfill_any: + # If one trial can be fulfilled, it will be fulfilled eventually + self._no_running_trials_since = -1 + return + + # Otherwise, can fulfill none + msg = _get_insufficient_resources_warning_msg( + for_train=self._for_train, trial=all_trials[0] + ) + logger.warning(msg) + self._no_running_trials_since = time.monotonic() + else: + self._no_running_trials_since = -1 + self._last_trial_num = len(all_trials) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/execution/placement_groups.py b/.venv/lib/python3.11/site-packages/ray/tune/execution/placement_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..0848b147878d752ac83d29b751cc7aa88906b71d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/execution/placement_groups.py @@ -0,0 +1,131 @@ +import warnings +from typing import Dict, Optional + +from ray.air.execution.resources.request import ResourceRequest +from ray.util.annotations import DeveloperAPI, PublicAPI +from ray.util.placement_group import placement_group + + +@PublicAPI(stability="beta") +class PlacementGroupFactory(ResourceRequest): + """Wrapper class that creates placement groups for trials. + + This function should be used to define resource requests for Ray Tune + trials. It holds the parameters to create + :ref:`placement groups `. + At a minimum, this will hold at least one bundle specifying the + resource requirements for each trial: + + .. code-block:: python + + from ray import tune + + tuner = tune.Tuner( + tune.with_resources( + train, + resources=tune.PlacementGroupFactory([ + {"CPU": 1, "GPU": 0.5, "custom_resource": 2} + ]) + ) + ) + tuner.fit() + + If the trial itself schedules further remote workers, the resource + requirements should be specified in additional bundles. You can also + pass the placement strategy for these bundles, e.g. to enforce + co-located placement: + + .. code-block:: python + + from ray import tune + + tuner = tune.Tuner( + tune.with_resources( + train, + resources=tune.PlacementGroupFactory([ + {"CPU": 1, "GPU": 0.5, "custom_resource": 2}, + {"CPU": 2}, + {"CPU": 2}, + ], strategy="PACK") + ) + ) + tuner.fit() + + The example above will reserve 1 CPU, 0.5 GPUs and 2 custom_resources + for the trainable itself, and reserve another 2 bundles of 2 CPUs each. + The trial will only start when all these resources are available. This + could be used e.g. if you had one learner running in the main trainable + that schedules two remote workers that need access to 2 CPUs each. + + If the trainable itself doesn't require resources. + You can specify it as: + + .. code-block:: python + + from ray import tune + + tuner = tune.Tuner( + tune.with_resources( + train, + resources=tune.PlacementGroupFactory([ + {}, + {"CPU": 2}, + {"CPU": 2}, + ], strategy="PACK") + ) + ) + tuner.fit() + + Args: + bundles: A list of bundles which + represent the resources requirements. + strategy: The strategy to create the placement group. + + - "PACK": Packs Bundles into as few nodes as possible. + - "SPREAD": Places Bundles across distinct nodes as even as possible. + - "STRICT_PACK": Packs Bundles into one node. 
The group is + not allowed to span multiple nodes. + - "STRICT_SPREAD": Packs Bundles across distinct nodes. + *args: Passed to the call of ``placement_group()`` + **kwargs: Passed to the call of ``placement_group()`` + + """ + + def __call__(self, *args, **kwargs): + warnings.warn( + "Calling PlacementGroupFactory objects is deprecated. Use " + "`to_placement_group()` instead.", + DeprecationWarning, + ) + kwargs.update(self._bound.kwargs) + # Call with bounded *args and **kwargs + return placement_group(*self._bound.args, **kwargs) + + +@DeveloperAPI +def resource_dict_to_pg_factory(spec: Optional[Dict[str, float]] = None): + """Translates resource dict into PlacementGroupFactory.""" + spec = spec or {"cpu": 1} + + spec = spec.copy() + + cpus = spec.pop("cpu", spec.pop("CPU", 0.0)) + gpus = spec.pop("gpu", spec.pop("GPU", 0.0)) + memory = spec.pop("memory", 0.0) + + # If there is a custom_resources key, use as base for bundle + bundle = {k: v for k, v in spec.pop("custom_resources", {}).items()} + + # Otherwise, consider all other keys as custom resources + if not bundle: + bundle = spec + + bundle.update( + { + "CPU": cpus, + "GPU": gpus, + "memory": memory, + } + ) + + return PlacementGroupFactory([bundle]) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/execution/tune_controller.py b/.venv/lib/python3.11/site-packages/ray/tune/execution/tune_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..bb482a80e6e05f7f036f5339a68b20e9336be23b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/execution/tune_controller.py @@ -0,0 +1,2181 @@ +import copy +import json +import logging +import os +import time +import traceback +import warnings +from collections import defaultdict, deque +from datetime import datetime +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import ray +from ray.air import ResourceRequest +from ray.air.constants import TIME_THIS_ITER_S +from ray.air.execution import PlacementGroupResourceManager, ResourceManager +from ray.air.execution._internal import RayActorManager, TrackedActor +from ray.exceptions import RayActorError, RayTaskError +from ray.train import CheckpointConfig +from ray.train._internal.session import _FutureTrainingResult, _TrainingResult +from ray.train._internal.storage import StorageContext +from ray.tune.callback import Callback, CallbackList +from ray.tune.error import TuneError, _AbortTrialExecution, _TuneStopTrialError +from ray.tune.execution.class_cache import _ActorClassCache +from ray.tune.execution.experiment_state import ( + _ExperimentCheckpointManager, + _find_newest_experiment_checkpoint, +) +from ray.tune.execution.insufficient_resources_manager import ( + _InsufficientResourcesManager, +) +from ray.tune.execution.placement_groups import PlacementGroupFactory +from ray.tune.experiment import Experiment, Trial +from ray.tune.experiment.trial import ( + _change_working_directory, + _get_trainable_kwargs, + _Location, + _noop_logger_creator, + _TrialInfo, +) +from ray.tune.result import ( + DEBUG_METRICS, + DEFAULT_METRIC, + DONE, + RESULT_DUPLICATE, + SHOULD_CHECKPOINT, + STDERR_FILE, + STDOUT_FILE, + TRIAL_INFO, +) +from ray.tune.schedulers import FIFOScheduler, TrialScheduler +from ray.tune.search import BasicVariantGenerator, SearchAlgorithm +from ray.tune.stopper import NoopStopper, Stopper +from ray.tune.tune_config import ResumeConfig +from ray.tune.utils import flatten_dict, warn_if_slow +from 
ray.tune.utils.log import Verbosity, _dedup_logs, has_verbosity +from ray.tune.utils.object_cache import _ObjectCache +from ray.tune.utils.resource_updater import _ResourceUpdater +from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder +from ray.util.annotations import DeveloperAPI +from ray.util.debug import log_once + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class TuneController: + CKPT_FILE_TMPL = "experiment_state-{}.json" + RAISE = "RAISE" + + def __init__( + self, + *, + search_alg: Optional[SearchAlgorithm] = None, + placeholder_resolvers: Optional[Dict[Tuple, Any]] = None, + scheduler: Optional[TrialScheduler] = None, + stopper: Optional[Stopper] = None, + resume_config: Optional[ResumeConfig] = None, + fail_fast: bool = False, + checkpoint_period: Union[str, int] = None, + callbacks: Optional[List[Callback]] = None, + metric: Optional[str] = None, + trial_checkpoint_config: Optional[CheckpointConfig] = None, + storage: Optional[StorageContext] = None, + reuse_actors: bool = False, + resource_manager_factory: Optional[Callable[[], ResourceManager]] = None, + _trainer_api: bool = False, + ): + if resource_manager_factory: + resource_manager = resource_manager_factory() + else: + resource_manager = PlacementGroupResourceManager() + + self._actor_manager = RayActorManager(resource_manager=resource_manager) + + self._class_cache = _ActorClassCache() + + # Resource status + self._resource_updater = _ResourceUpdater(None) + + # Actor <-> Trial mappings + self._actor_to_trial: Dict[TrackedActor, Trial] = {} + self._trial_to_actor: Dict[Trial, TrackedActor] = {} + + # Resources <-> Trial + self._resources_to_pending_trials: Dict[ + ResourceRequest, Set[Trial] + ] = defaultdict(set) + + # Keep track of actor states + self._pending_trials: Set[Trial] = set() + self._pending_trials_list: List[Trial] = [] + + self._running_trials: Set[Trial] = set() + + self._paused_trials: Set[Trial] = set() + + self._stopped_trials: Set[Trial] = set() + self._failed_trials: Set[Trial] = set() + + self._resetting_trials: Set[Trial] = set() + self._staged_trials: Set[Trial] = set() + + # Removed actors + self._started_actors: Set[TrackedActor] = set() + + # Map of tracked actors -> timestamp + # The timestamp is when we requested the stop. + # We track these actors here to force a + # cleanup after some time (as they might be hanging). + # Todo: This timeout logic should be moved into the actor manager. + # This map is populated whenever we request an actor stop: + # - Regular STOP decision + # - Removing an actor because its trial REUSEs a different trial's actor + # - Removing a cached actor because it's not needed anymore + # Actors are only tracked in this map if they actually started (not if they + # were only requested but never started). + # Actors are removed from this map: + # - When the STOP resolved and the actor actually stopped + # - When they are forcefully cleaned up after the timeout. 
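+ # The forced-cleanup timeout below is configurable via the TUNE_FORCE_TRIAL_CLEANUP_S environment variable (default: 600 seconds).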
+ self._stopping_actors: Dict[TrackedActor, float] = {} + self._earliest_stopping_actor: float = float("inf") + self._actor_cleanup_timeout: int = int( + os.environ.get("TUNE_FORCE_TRIAL_CLEANUP_S", "600") + ) + self._actor_force_cleanup_timeout: int = 10 + + # Reuse actors + self._reuse_actors = reuse_actors + self._actor_cache = _ObjectCache(may_keep_one=True) + + # Trial metadata for experiment checkpoints + self._trials_to_cache: Set[Trial] = set() + self._trial_metadata: Dict[str, str] = {} + + # TRAINING + self._buffer_length = int(os.getenv("TUNE_RESULT_BUFFER_LENGTH", 1)) + self._buffer_min_time_s = float(os.getenv("TUNE_RESULT_BUFFER_MIN_TIME_S", 0.0)) + self._buffer_max_time_s = float( + os.getenv("TUNE_RESULT_BUFFER_MAX_TIME_S", 100.0) + ) + + # Legacy TrialRunner init + self._search_alg = search_alg or BasicVariantGenerator() + self._placeholder_resolvers = placeholder_resolvers + self._scheduler_alg = scheduler or FIFOScheduler() + self._callbacks = CallbackList(callbacks or []) + self._insufficient_resources_manager = _InsufficientResourcesManager( + for_train=_trainer_api + ) + self._pending_trial_queue_times = {} + + self._max_pending_trials = _get_max_pending_trials(self._search_alg) + + self._storage = storage + self._metric = metric + + self._total_time = 0 + self._iteration = 0 + self._has_errored = False + self._fail_fast = fail_fast + if isinstance(self._fail_fast, str): + self._fail_fast = self._fail_fast.upper() + if self._fail_fast == self.RAISE: + warnings.warn( + "fail_fast='raise' detected. Be careful when using this " + "mode as resources (such as Ray processes, " + "file descriptors, and temporary files) may not be " + "cleaned up properly. To use " + "a safer mode, use fail_fast=True." + ) + else: + raise ValueError( + "fail_fast must be one of {bool, RAISE}. " f"Got {self._fail_fast}." 
+ ) + + self._print_trial_errors = bool( + int(os.environ.get("TUNE_PRINT_ALL_TRIAL_ERRORS", "1")) + ) + + self._trials: List[Trial] = [] + self._live_trials: Set[Trial] = set() # Set of non-terminated trials + self._cached_trial_decisions = {} + self._queued_trial_decisions = {} + + self._stop_queue = [] + self._should_stop_experiment = False # used by TuneServer + + self._stopper = stopper or NoopStopper() + + self._start_time = time.time() + + self._session_str = datetime.fromtimestamp(self._start_time).strftime( + "%Y-%m-%d_%H-%M-%S" + ) + + if checkpoint_period is None: + checkpoint_period = os.getenv("TUNE_GLOBAL_CHECKPOINT_S", "auto") + + self._checkpoint_period = checkpoint_period + self._trial_checkpoint_config = trial_checkpoint_config or CheckpointConfig() + self._checkpoint_manager = self._create_checkpoint_manager() + + self._resumed = False + + if resume_config is not None: + # Use the metadata file to restore TuneController state + try: + self.resume(resume_config=resume_config) + self._resumed = True + except Exception as e: + if has_verbosity(Verbosity.V3_TRIAL_DETAILS): + logger.error(str(e)) + logger.exception("Failed to restore the run state.") + if self._fail_fast: + raise + logger.info("Restarting experiment.") + else: + logger.debug("Starting a new experiment.") + + def _wrapped(self): + """Return wrapped tune controller to be passed to scheduler/searchers.""" + return TrialRunnerWrapper( + self, + trial_executor=_FakeRayTrialExecutor(self), + runner_whitelist_attr={ + "search_alg", + "get_trials", + "get_live_trials", + "_set_trial_status", + "pause_trial", + "stop_trial", + "_schedule_trial_save", + }, + executor_whitelist_attr={ + "has_resources_for_trial", + "pause_trial", + "save", + "_resource_updater", + }, + ) + + @property + def resumed(self): + return self._resumed + + @property + def search_alg(self): + return self._search_alg + + @property + def scheduler_alg(self): + return self._scheduler_alg + + def setup_experiments( + self, experiments: List[Experiment], total_num_samples: int + ) -> None: + """Obtains any necessary information from experiments. + + Mainly used to setup callbacks. + + Args: + experiments: List of Experiments + to use. + total_num_samples: Total number of samples + factoring in grid search samplers. + """ + experiment = experiments[0] + spec = experiment.public_spec if experiment else {} + spec["total_num_samples"] = total_num_samples + self._callbacks.setup(**spec) + + def end_experiment_callbacks(self) -> None: + """Calls ``on_experiment_end`` method in callbacks.""" + self._callbacks.on_experiment_end(trials=self._trials) + + @property + def experiment_state_file_name(self) -> str: + return self.CKPT_FILE_TMPL.format(self._session_str) + + @property + def experiment_state_path(self) -> str: + """Returns the local experiment checkpoint path.""" + return Path( + self._storage.experiment_driver_staging_path, + self.experiment_state_file_name, + ).as_posix() + + @property + def experiment_path(self) -> str: + return self._storage.experiment_fs_path + + def _create_checkpoint_manager(self): + return _ExperimentCheckpointManager( + storage=self._storage, + checkpoint_period=self._checkpoint_period, + sync_every_n_trial_checkpoints=self._trial_checkpoint_config.num_to_keep, + ) + + def save_to_dir(self): + """Save TuneController state to the local staging experiment directory. 
+ + This includes: + - trial states + - TuneController internal state (all the serializable attributes) + - the searcher state + - the callback states + """ + # Get state from trial executor and runner + runner_state = { + # Trials + "trial_data": list(self._get_trial_checkpoints().values()), + # Experiment data + "runner_data": self.__getstate__(), + # Metadata + "stats": {"start_time": self._start_time}, + } + + driver_staging_path = self._storage.experiment_driver_staging_path + os.makedirs(driver_staging_path, exist_ok=True) + with open( + Path(driver_staging_path, self.experiment_state_file_name), + "w", + ) as f: + json.dump(runner_state, f, cls=TuneFunctionEncoder) + + self._search_alg.save_to_dir(driver_staging_path, session_str=self._session_str) + self._callbacks.save_to_dir(driver_staging_path, session_str=self._session_str) + + def checkpoint(self, force: bool = False, wait: bool = False): + self._checkpoint_manager.sync_up_experiment_state( + save_fn=self.save_to_dir, force=force, wait=wait + ) + + def _requeue_restored_trials( + self, trials: List[Trial], resume_config: ResumeConfig + ): + # Set trial statuses according to the resume configuration + for trial in sorted( + trials, key=lambda t: t.run_metadata.last_result_time, reverse=True + ): + if trial.status == Trial.ERROR: + resume_type = resume_config.errored + elif trial.status == Trial.TERMINATED: + resume_type = resume_config.finished + else: # Unfinished (PENDING, RUNNING, PAUSED) + resume_type = resume_config.unfinished + + trial_to_add = None + if resume_type == ResumeConfig.ResumeType.RESUME: + # Keep trial ID on resume + trial_to_add = trial + trial_to_add.run_metadata.error_filename = None + trial_to_add.run_metadata.pickled_error_filename = None + trial_to_add.set_status(Trial.PENDING) + elif resume_type == ResumeConfig.ResumeType.RESTART: + trial_to_add = trial.reset() + trial_to_add.restore_path = None + elif resume_type == ResumeConfig.ResumeType.SKIP: + trial_to_add = trial + if trial_to_add.status != Trial.ERROR: + # Set the status to terminated to skip it. + # Keep errored trial status as ERROR. + trial_to_add.set_status(Trial.TERMINATED) + else: + raise ValueError(f"Unknown resume type: {resume_type}") + assert trial_to_add is not None + + self.add_trial(trial_to_add) + + def _restore_trials(self, experiment_state: Dict) -> List[Trial]: + trials = [] + for trial_json_state, trial_runtime_metadata in experiment_state["trial_data"]: + trial = Trial.from_json_state(trial_json_state) + trial.restore_run_metadata(trial_runtime_metadata) + + # The following properties may be updated on restoration + # Ex: moved local/cloud experiment directory + + # Propagate updated storage ctx properties to the trial's restored copy. + new_storage = copy.copy(trial.storage) + new_storage.storage_filesystem = self._storage.storage_filesystem + new_storage.storage_fs_path = self._storage.storage_fs_path + new_storage.experiment_dir_name = self._storage.experiment_dir_name + + # ATTN: `trial.set_storage` is used intentionally, since it + # also updates the absolute paths and filesystem of tracked checkpoints. + trial.set_storage(new_storage) + + # Avoid creating logdir in client mode for returned trial results, + # since the dir might not be creatable locally. + # TODO(ekl) this is kind of a hack. + if not ray.util.client.ray.is_connected(): + trial.init_local_path() # Create logdir if it does not exist + + trials.append(trial) + + # NOTE: The restored run should reuse the same driver staging directory. 
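+ # Adopting the restored trials' timestamp keeps the driver staging path consistent with the one used by the original run.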
+ self._storage._timestamp = trials[0].storage._timestamp + + return trials + + def resume(self, resume_config: ResumeConfig): + """Resumes all checkpointed trials from previous run. + + Requires user to manually re-register their objects. Also stops + all ongoing trials. + """ + # 1. Restore TuneController state + # Find newest state file + newest_state_path = _find_newest_experiment_checkpoint( + self._storage.experiment_fs_path, fs=self._storage.storage_filesystem + ) + + if newest_state_path is None: + raise ValueError( + f"Tried to resume experiment from directory " + f"'{self._storage.experiment_fs_path}', but no " + f"experiment state file of the form '{TuneController.CKPT_FILE_TMPL}' " + "was found. This is expected if you are launching a new experiment." + ) + + logger.info( + "Restoring the run from the latest experiment state file: " + f"{Path(newest_state_path).name}" + ) + with self._storage.storage_filesystem.open_input_stream(newest_state_path) as f: + experiment_state = json.loads(f.readall(), cls=TuneFunctionDecoder) + + self.__setstate__(experiment_state["runner_data"]) + + # 2. Get the trial states that the run left off at. + trials = self._restore_trials(experiment_state) + + # 3. Restore search algorithm and callback state + # Download the search algorithm and callback state to the driver staging dir. + self._checkpoint_manager.sync_down_experiment_state() + + driver_staging_dir = self._storage.experiment_driver_staging_path + if self._search_alg.has_checkpoint(driver_staging_dir): + self._search_alg.restore_from_dir(driver_staging_dir) + + if self._callbacks.can_restore(driver_staging_dir): + self._callbacks.restore_from_dir(driver_staging_dir) + + # 4. Re-queue trials as needed, depending on their status. + self._requeue_restored_trials(trials, resume_config) + + def update_max_pending_trials(self, max_pending_trials: Optional[int] = None): + self._max_pending_trials = max_pending_trials or _get_max_pending_trials( + self._search_alg + ) + + def update_pending_trial_resources( + self, resources: Union[dict, PlacementGroupFactory] + ): + """Update trial resources when resuming from checkpoint. + + Only updating the pending ones. + """ + assert resources + if isinstance(resources, dict) and "gpu" not in resources: + resources["gpu"] = 0 + for trial in self._trials: + if trial.status == Trial.PENDING: + trial.update_resources(resources=resources) + + def is_finished(self): + """Returns whether all trials have finished running.""" + # The checks here are partly redundant but optimized for quick + # evaluation. Specifically, if there are live trials, we check + # these live trials first. Only if none of the live trials is + # live anymore do we loop over all trials for a final check. + trials_done = ( + len(self._live_trials) == 0 + or all(trial.is_finished() for trial in self._live_trials) + ) and all(trial.is_finished() for trial in self._trials) + return trials_done and self._search_alg.is_finished() + + def get_trial(self, tid): + trial = [t for t in self._trials if t.trial_id == tid] + return trial[0] if trial else None + + def get_trials(self): + """Returns the list of trials managed by this TrialRunner. + + Note that the caller usually should not mutate trial state directly. + """ + return self._trials + + def get_live_trials(self): + """Returns the set of trials that are not in Trial.TERMINATED state.""" + return self._live_trials + + def add_trial(self, trial: Trial): + """Adds a new trial to this TrialRunner. + + Trials may be added at any time. 
+ + Args: + trial: Trial to queue. + """ + # If the config map has had all the references replaced with placeholders, + # resolve them before adding the trial. + if self._placeholder_resolvers: + trial.resolve_config_placeholders(self._placeholder_resolvers) + + # With trial.config resolved, create placement group factory if needed. + trial.create_placement_group_factory() + + self._trials.append(trial) + if trial.status != Trial.TERMINATED: + self._live_trials.add(trial) + with warn_if_slow("scheduler.on_trial_add"): + self._scheduler_alg.on_trial_add(self._wrapped(), trial) + self._mark_trial_to_checkpoint(trial) + + logger.debug(f"Adding trial {trial} with status {trial.status}") + + status_str_map = { + Trial.PENDING: self._pending_trials, + Trial.RUNNING: self._running_trials, + Trial.PAUSED: self._paused_trials, + Trial.TERMINATED: self._stopped_trials, + Trial.ERROR: self._failed_trials, + } + + status_str_map[trial.status].add(trial) + + if trial.status == Trial.PENDING: + self._pending_trials_list.append(trial) + self._resources_to_pending_trials[trial.placement_group_factory].add(trial) + + def _update_trial_queue(self, blocking: bool = False, timeout: int = 600) -> bool: + """Adds next trials to queue if possible. + + Note that the timeout is currently unexposed to the user. + + Args: + blocking: Blocks until either a trial is available + or is_finished (timeout or search algorithm finishes). + timeout: Seconds before blocking times out. + + Returns: + Boolean indicating if a new trial was created or not. + """ + trial = self._search_alg.next_trial() + if blocking and not trial: + start = time.time() + # Checking `is_finished` instead of _search_alg.is_finished + # is fine because blocking only occurs if all trials are + # finished and search_algorithm is not yet finished + while ( + not trial and not self.is_finished() and time.time() - start < timeout + ): + logger.debug("Blocking for next trial...") + trial = self._search_alg.next_trial() + time.sleep(1) + + if trial: + self.add_trial(trial) + return True + + return False + + def _used_resources_string(self) -> str: + allocated_resources = self._actor_manager.get_live_actors_resources() + + return self._resource_updater.debug_string(allocated_resources) + + def on_step_begin(self): + self._resource_updater.update_avail_resources() + + def on_step_end(self): + self._cleanup_cached_actors(force_all=False) + self._cleanup_stopping_actors(force_all=False) + + def _cleanup_cached_actors(self, force_all: bool = False): + if ( + self._search_alg.is_finished() + and not self._staged_trials + and self._actor_cache.total_max_objects == 0 + ): + # If there are no more trials coming in, no trials are pending execution, + # and we don't explicitly want to cache objects, we can evict the full + # cache. 
+ force_all = True + + for tracked_actor in self._actor_cache.flush_cached_objects( + force_all=force_all + ): + logger.debug(f"Cleaning up cached actor: {tracked_actor}") + # Unset termination callbacks as no trial is associated + tracked_actor.set_on_stop(None) + tracked_actor.set_on_error(None) + self._remove_actor(tracked_actor=tracked_actor) + + def _cleanup_stopping_actors(self, force_all: bool = False): + now = time.monotonic() + + if ( + not force_all + and now - self._earliest_stopping_actor <= self._actor_cleanup_timeout + ): + # If the earliest actor to timeout has not reached the timeout, return + return + + # This is a bit costly, so we want to avoid running it too often + times = deque( + sorted( + [ + (timestamp, tracked_actor) + for tracked_actor, timestamp in self._stopping_actors.items() + ], + key=lambda item: item[0], + ) + ) + + while times and ( + force_all or time.monotonic() - times[0][0] > self._actor_cleanup_timeout + ): + if ( + time.monotonic() - times[0][0] < self._actor_force_cleanup_timeout + ) and self._actor_manager.is_actor_started(tracked_actor=times[0][1]): + # Even if force_all=True, we give the actors time to clean up + self._actor_manager.next(timeout=1) + continue + + _, tracked_actor = times.popleft() + + if tracked_actor not in self._stopping_actors: + # Actor stopping has been handled by the block above + continue + + if self._actor_manager.is_actor_started(tracked_actor=tracked_actor): + logger.debug(f"Forcefully killing actor: {tracked_actor}") + self._actor_manager.remove_actor(tracked_actor=tracked_actor, kill=True) + self._stopping_actors.pop(tracked_actor) + + if times: + self._earliest_stopping_actor = times[0][0] + else: + self._earliest_stopping_actor = float("inf") + + def step(self): + if self.is_finished(): + raise TuneError("Called step when all trials finished?") + + with warn_if_slow("on_step_begin"): + self.on_step_begin() + + with warn_if_slow("callbacks.on_step_begin"): + self._callbacks.on_step_begin( + iteration=self._iteration, trials=self._trials + ) + + # Ask searcher for more trials + self._maybe_update_trial_queue() + + # Start actors for added trials + self._maybe_add_actors() + + # Handle one event + if not self._actor_manager.next(timeout=0.1): + # If there are no actors running, warn about potentially + # insufficient resources + if not self._actor_manager.num_live_actors: + self._insufficient_resources_manager.on_no_available_trials( + self.get_trials() + ) + + # Maybe stop whole experiment + self._stop_experiment_if_needed() + + # Maybe save experiment state + try: + self.checkpoint() + except Exception as e: + logger.warning(f"Trial controller checkpointing failed: {str(e)}") + raise e + + self._iteration += 1 + + with warn_if_slow("on_step_end"): + self.on_step_end() + with warn_if_slow("callbacks.on_step_end"): + self._callbacks.on_step_end(iteration=self._iteration, trials=self._trials) + + def _set_trial_status(self, trial: Trial, status: str): + """Set trial to a specific status. + + This will keep track of trials with specific statuses in sets. + + For PENDING and PAUSED trials we also keep a list of trials to be able + to retain FIFO ordering. See ``_maybe_add_actors`` for details. + + Lastly we also keep a mapping from resources to pending/paused trials + to be able to efficiently start trials for cached actors. + """ + current_status = trial.status + + if current_status == status: + logger.debug(f"Trial {trial} already has status {status}. 
Skipping update.") + return + + status_str_map = { + Trial.PENDING: self._pending_trials, + Trial.RUNNING: self._running_trials, + Trial.PAUSED: self._paused_trials, + Trial.TERMINATED: self._stopped_trials, + Trial.ERROR: self._failed_trials, + } + + logger.debug( + f"Setting status for trial {trial} from {current_status} to {status}" + ) + + assert trial in status_str_map[current_status], (trial, current_status) + assert trial not in status_str_map[status], (trial, status) + + status_str_map[current_status].remove(trial) + status_str_map[status].add(trial) + + # We keep a log for pending trials for FIFO scheduling. + # We do not need to remove from this list as we will just discard + # items that are in this list but not in the respective set. + if status == Trial.PENDING: + self._pending_trials_list.append(trial) + self._resources_to_pending_trials[trial.placement_group_factory].add(trial) + else: + self._resources_to_pending_trials[trial.placement_group_factory].discard( + trial + ) + + trial.set_status(status) + + def _get_trial_checkpoints(self) -> Dict[str, str]: + for trial in self._trials_to_cache: + self._trial_metadata[trial.trial_id] = trial.get_json_state() + self._trials_to_cache.clear() + return self._trial_metadata + + def _mark_trial_to_checkpoint(self, trial: Trial): + self._trials_to_cache.add(trial) + + ### + # UPDATE TRIALS + def _maybe_update_trial_queue(self): + """Ask the searcher for more trials.""" + if self._search_alg.is_finished(): + return + + dont_wait_for_trial = ( + self._pending_trials or self._running_trials or self._paused_trials + ) + + while len(self._pending_trials) < self._max_pending_trials: + if not self._update_trial_queue(blocking=not dont_wait_for_trial): + break + dont_wait_for_trial = True + + def _cleanup_trials(self): + logger.debug("CLEANING UP all trials") + + for tracked_actor in list(self._actor_to_trial): + trial = self._actor_to_trial[tracked_actor] + logger.debug( + f"Scheduling trial stop at end of experiment (trial {trial}): " + f"{tracked_actor}" + ) + self._schedule_trial_stop(trial) + + # Clean up cached actors now + self._cleanup_cached_actors(force_all=True) + + start = time.monotonic() + while time.monotonic() - start < 5 and self._actor_manager.num_total_actors: + if _dedup_logs("actor_manager_cleanup", str(start)): + logger.debug( + "Waiting for actor manager to clean up final state [dedup]" + ) + self._actor_manager.next(timeout=1) + + logger.debug("Force cleanup of remaining actors") + self._cleanup_stopping_actors(force_all=True) + + self._actor_manager.cleanup() + + def _remove_actor(self, tracked_actor: TrackedActor): + stop_future = self._actor_manager.schedule_actor_task( + tracked_actor, "stop", _return_future=True + ) + now = time.monotonic() + + if self._actor_manager.remove_actor( + tracked_actor, kill=False, stop_future=stop_future + ): + # If the actor was previously alive, track + self._stopping_actors[tracked_actor] = now + self._earliest_stopping_actor = min(self._earliest_stopping_actor, now) + + ### + # ADD ACTORS + def _maybe_add_actors(self) -> None: + """Add actors for pending and paused trials. + + For actors that have not been staged, yet, we request an actor. + + For actors that have been staged, already, we try to reuse a cached actor. + + First, we handle the trial that the scheduler chooses to run. + + Then, we handle all trials that are pending. + + Lastly, we see if we have cached actors that we can assign to a pending or + paused trial. 
This can be the case when a trial has not been staged, yet, + for instance because the number of staging trials was too large. + """ + + ### + # 1: Start trial that the scheduler wants to run + with warn_if_slow("choose_trial_to_run"): + trial_to_run = self._scheduler_alg.choose_trial_to_run(self._wrapped()) + + if trial_to_run: + if _dedup_logs("trial_to_run_chosen", trial_to_run.trial_id): + logger.debug( + f"Chose trial to run from scheduler: {trial_to_run} [dedup]" + ) + if ( + trial_to_run not in self._staged_trials + and trial_to_run not in self._trial_to_actor + ): + logger.debug(f"Staging trial to run: {trial_to_run}") + self._set_trial_status(trial_to_run, Trial.PENDING) + self._staged_trials.add(trial_to_run) + self._actor_cache.increase_max(trial_to_run.placement_group_factory) + # schedule_trial_actor also potentially uses cached actors + self._schedule_trial_actor(trial_to_run) + else: + # Otherwise, only try to use the cached actor + if _dedup_logs("trial_to_run_reuse", trial_to_run.trial_id): + logger.debug( + f"Trying to re-use actor for trial to run: {trial_to_run} " + f"[dedup]" + ) + self._maybe_reuse_cached_actor(trial_to_run) + + ### + # 2: Start trials that are PENDING + def _maybe_add_actors(candidates: List[Trial]): + new_candidates = [] + + while candidates: + if self._actor_manager.num_pending_actors >= self._max_pending_trials: + break + + trial = candidates.pop(0) + + # If the trial is part of the list, but not of the set, + # we just ignore it. Removing it from the list on status + # change is too expensive. + if trial not in self._pending_trials: + continue + + if trial in self._trial_to_actor: + new_candidates.append(trial) + continue + + if trial in self._staged_trials: + self._maybe_reuse_cached_actor(trial) + continue + + logger.debug(f"Scheduling actor for enqueued trial: {trial}") + self._staged_trials.add(trial) + self._actor_cache.increase_max(trial.placement_group_factory) + self._schedule_trial_actor(trial) + + return new_candidates + candidates + + self._pending_trials_list = _maybe_add_actors(self._pending_trials_list) + + ### + # 3: Start any trial that can be started with a cached actor + if self._actor_cache.num_cached_objects: + for resource in self._resources_to_pending_trials: + if not self._resources_to_pending_trials[resource]: + continue + + if not self._actor_cache.has_cached_object(resource): + continue + + start_trial = self._resources_to_pending_trials[resource].pop() + logger.debug( + f"Trying to re-use actor for enqueued trial: {start_trial}" + ) + if not self._maybe_reuse_cached_actor(start_trial): + self._resources_to_pending_trials[resource].add(start_trial) + else: + if start_trial not in self._staged_trials: + self._staged_trials.add(start_trial) + self._actor_cache.increase_max( + start_trial.placement_group_factory + ) + + def _maybe_reuse_cached_actor(self, trial: Trial) -> bool: + """Maybe reuse a cached actor for a trial. + + If an actor has been scheduled for the trial already, + this will remove the original actor. 
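+ Returns True if a cached actor was assigned to the trial (or a reset onto a reused actor is already in flight), and False if no suitable cached actor was available.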
+ """ + if trial in self._resetting_trials: + return True + + resource_request = trial.placement_group_factory + + if not self._actor_cache.has_cached_object(resource_request): + return False + + cached_actor = self._actor_cache.pop_cached_object(resource_request) + logger.debug(f"Reusing ACTOR for trial {trial}: {cached_actor}") + + if trial in self._trial_to_actor: + original_actor = self._trial_to_actor.pop(trial) + self._actor_to_trial.pop(original_actor) + + logger.debug(f"Removing ORIGINAL ACTOR for trial {trial}: {original_actor}") + self._remove_actor(tracked_actor=original_actor) + + self._trial_to_actor[trial] = cached_actor + self._actor_to_trial[cached_actor] = trial + + # Todo: get rid of Trial.runner + ray_actor = self._actor_manager._live_actors_to_ray_actors_resources[ + cached_actor + ][0] + trial.set_ray_actor(ray_actor) + + self._schedule_trial_reset(trial, trial.config, trial.experiment_tag) + + return True + + def _schedule_trial_actor(self, trial: Trial): + """Schedule an actor for a trial. + + If a cached actor is available, use it. Otherwise, request a + new actor. + """ + logger.debug(f"Trying to schedule new ACTOR for trial {trial}") + + assert trial.status == Trial.PENDING + + trial.init_local_path() + # We checkpoint metadata here to try mitigating logdir duplication + self._mark_trial_to_checkpoint(trial) + + if self._maybe_reuse_cached_actor(trial): + return + + # Safeguard + if trial in self._trial_to_actor: + raise RuntimeError( + f"Tried to request a new actor for trial {trial}, but an old " + f"actor still exists. This can lead to leaked resources. The old " + f"actor should be removed first. " + f"This is an internal problem in Ray Tune. If you encounter this " + f"error, please raise an issue on " + f"https://github.com/ray-project/ray/issues" + ) + + trainable_cls = trial.get_trainable_cls() + if not trainable_cls: + exception = _AbortTrialExecution( + f"Invalid trainable: {trial.trainable_name}. If you passed " + f"a string, make sure the trainable was registered before." + ) + trial.handle_error(exception) + self._schedule_trial_stop(trial, exception=exception) + return + + _actor_cls = self._class_cache.get(trainable_cls) + + trial.set_location(_Location()) + trainable_kwargs = _get_trainable_kwargs(trial=trial) + + with _change_working_directory(trial): + tracked_actor = self._actor_manager.add_actor( + cls=_actor_cls, + resource_request=trial.placement_group_factory, + kwargs=trainable_kwargs, + on_start=self._actor_started, + on_stop=self._actor_stopped, + on_error=self._actor_failed, + ) + self._trial_to_actor[trial] = tracked_actor + self._actor_to_trial[tracked_actor] = trial + + logger.debug( + f"Scheduled new ACTOR for trial {trial}: {tracked_actor}. " + f"Resources: {trial.placement_group_factory}" + ) + + def _unstage_trial_with_resources(self, trial: Trial): + """Unstage trial, or one with the same resources as ``trial``.""" + # Case 1: The trial we started was staged. Just remove it + if trial in self._staged_trials: + self._staged_trials.remove(trial) + self._actor_cache.decrease_max(trial.placement_group_factory) + return + + # Case 2: We staged a trial "A" with the same resources, but our trial "B" + # was selected by the scheduler to run. The resource manager does not care + # about "trials", it just cares about resources being available. 
Thus we + # look for a staged trial with the same resource requirements and remove it + + resource_request = trial.placement_group_factory + # Remove staged trial with same resource requirements + candidate_trial = None + for staged_trial in self._staged_trials: + staged_resources = staged_trial.placement_group_factory + if staged_resources == resource_request: + candidate_trial = staged_trial + break + + if candidate_trial: + self._staged_trials.remove(candidate_trial) + self._actor_cache.decrease_max(candidate_trial.placement_group_factory) + return + + raise RuntimeError( + "Started a trial with resources requested by a different trial, but " + "this trial was lost. This is an error in Ray Tune's execution " + "logic. Please raise a GitHub issue at " + "https://github.com/ray-project/ray/issues" + ) + + def _maybe_cache_trial_actor(self, trial: Trial) -> bool: + """Cache trial actor for reuse, if needed. + + We will only cache as many actors as are needed to fulfill any pending + resource requests for actors with the same resource requirements. + E.g. if we have 6 running trials and 4 additional staged actors, we will only + cache up to 4 of the running trial actors when they finish. + + One exception is the case when we have no cached actors, yet. In that case, + we will always cache the actor in this method. + + Later, in `_cleanup_cached_actors`, we will check again if we need this cached + actor. That method will keep the actor if we don't have any staged trials, + because we don't know at that point if the next trial might require the same + resources. But because there is no staged trial, it is safe to keep the actor + around, as it won't occupy resources needed by another trial until it's staged. + """ + if not self._reuse_actors: + return False + + if self._search_alg.is_finished() and not self._staged_trials: + logger.debug( + f"Not caching actor of trial {trial} as the search is over " + f"and no more trials are staged." + ) + return False + + tracked_actor = self._trial_to_actor[trial] + + if ( + not self._actor_manager.is_actor_started(tracked_actor) + or self._actor_manager.is_actor_failed(tracked_actor) + or tracked_actor not in self._started_actors + ): + logger.debug( + f"Not caching actor of trial {trial} as it has not been started, yet: " + f"{tracked_actor}" + ) + return False + + if not self._actor_cache.cache_object( + trial.placement_group_factory, tracked_actor + ): + logger.debug( + f"Could not cache actor of trial {trial} for " + "reuse, as there are no pending trials " + "requiring its resources." 
+ ) + return False + + logger.debug(f"Caching actor of trial {trial} for re-use: {tracked_actor}") + + tracked_actor = self._trial_to_actor.pop(trial) + self._actor_to_trial.pop(tracked_actor) + + trial.set_ray_actor(None) + + return True + + def _actor_started(self, tracked_actor: TrackedActor, log: str = "STARTED"): + self._started_actors.add(tracked_actor) + + trial = self._actor_to_trial[tracked_actor] + + logger.debug(f"Actor {log} for trial {trial}: {tracked_actor}") + + self._unstage_trial_with_resources(trial) + + ray_actor = self._actor_manager._live_actors_to_ray_actors_resources[ + tracked_actor + ][0] + trial.set_ray_actor(ray_actor) + + self._callbacks.on_trial_start( + iteration=self._iteration, trials=self._trials, trial=trial + ) + + self._set_trial_status(trial, Trial.RUNNING) + + self._mark_trial_to_checkpoint(trial) + + if not self._schedule_trial_restore(trial): + self._schedule_trial_train(trial) + + def _actor_stopped(self, tracked_actor: TrackedActor): + if tracked_actor in self._actor_to_trial: + trial = self._actor_to_trial.pop(tracked_actor) + logger.debug(f"Actor STOPPED for trial {trial}: {tracked_actor}") + self._trial_to_actor.pop(trial) + trial.set_ray_actor(None) + + logger.debug(f"Actor STOPPED: {tracked_actor}") + + self._stopping_actors.pop(tracked_actor, None) + self._started_actors.discard(tracked_actor) + + def _actor_failed(self, tracked_actor: TrackedActor, exception: Exception): + trial = self._actor_to_trial[tracked_actor] + + logger.debug( + f"Actor FAILED for trial {trial}: {tracked_actor}. " + f"Exception: {exception}" + ) + + if trial in (self._pending_trials | self._paused_trials): + # First, set to running (needed downstream in _process_trial_failure) + self._set_trial_status(trial, Trial.RUNNING) + + logger.debug( + f"Trial {trial} failed in its creation task. Unstaging " + f"to allow it to be re-scheduled." + ) + + self._unstage_trial_with_resources(trial) + self._trial_task_failure(trial, exception=exception) + + self._actor_manager.clear_actor_task_futures(tracked_actor) + + # Clean up actor + tracked_actor.set_on_stop(None) + tracked_actor.set_on_error(None) + self._actor_manager.remove_actor(tracked_actor, kill=False) + + # Trigger actor stopped callback + self._actor_stopped(tracked_actor) + + def _schedule_trial_task( + self, + trial: Trial, + method_name: str, + args: Optional[Tuple] = None, + kwargs: Optional[Dict] = None, + on_result: Optional[Callable[[Trial, Any], None]] = None, + on_error: Optional[Callable[[Trial, Exception], None]] = None, + _return_future: bool = False, + ) -> Optional[ray.ObjectRef]: + """Schedule an actor task future for a trial. + + This is a wrapper around ``ActorManager.schedule_actor_task``. This method + retrieves the tracked actor for a trial to kick off the task. + + It also wraps around the callbacks, retrieving the trial object given the + tracked actor. 
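+ If ``_return_future=True``, the future returned by the actor manager is handed back to the caller; otherwise nothing is returned.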
+ """ + + tracked_actor = self._trial_to_actor[trial] + + _on_result = None + _on_error = None + + args = args or tuple() + kwargs = kwargs or {} + + if on_result: + + def _on_result(tracked_actor: TrackedActor, *args, **kwargs): + assert trial == self._actor_to_trial[tracked_actor] + logger.debug( + f"Future {method_name.upper()} RESOLVED for trial {trial}: " + f"{args}, {kwargs}" + ) + try: + on_result(trial, *args, **kwargs) + except Exception as e: + logger.debug( + f"Error handling {method_name.upper()} result " + f"for trial {trial}: {e}" + ) + if e is TuneError or self._fail_fast == self.RAISE: + raise e + else: + raise TuneError(traceback.format_exc()) + + if on_error: + + def _on_error(tracked_actor: TrackedActor, exception: Exception): + # If the actor failed, it has already been cleaned up. + if tracked_actor not in self._actor_to_trial: + assert isinstance(exception, RayActorError), type(exception) + else: + assert trial == self._actor_to_trial[tracked_actor] + + logger.debug( + f"Future {method_name.upper()} FAILED for trial {trial}: " + f"{exception}" + ) + try: + on_error(trial, exception) + except Exception as e: + logger.debug( + f"Error handling {method_name.upper()} failure " + f"for trial {trial}: {e}" + ) + if e is TuneError or self._fail_fast == self.RAISE: + raise e + else: + raise TuneError(traceback.format_exc()) + + logger.debug(f"Future {method_name.upper()} SCHEDULED for trial {trial}") + + with _change_working_directory(trial): + future = self._actor_manager.schedule_actor_task( + tracked_actor=tracked_actor, + method_name=method_name, + args=args, + kwargs=kwargs, + on_result=_on_result, + on_error=_on_error, + _return_future=_return_future, + ) + if _return_future: + return future + + def _queue_decision(self, trial, decision): + # Get old decision, setting it to the current decision if it isn't set + old_decision = self._queued_trial_decisions.setdefault(trial.trial_id, decision) + + # Stopping always takes precedence. If we decided to stop, just quit + if old_decision is TrialScheduler.STOP: + return + + # The old decision wasn't STOP. We update the decision only if it is + # STOP or PAUSE. The action will only be CONTINUE if it was set by + # the first received result and was never updated after that. + if decision is TrialScheduler.STOP or decision is TrialScheduler.PAUSE: + self._queued_trial_decisions[trial.trial_id] = decision + + def _execute_action(self, trial: Trial, decision: str, after_save: bool = False): + """Executes action based on decision. + + Args: + trial: Trial to act on. + decision: Scheduling decision to undertake. 
+ """ + if decision == TrialScheduler.CONTINUE: + self._schedule_trial_train(trial) + elif decision == TrialScheduler.PAUSE: + self.pause_trial(trial, should_checkpoint=not after_save) + elif decision == TrialScheduler.STOP: + self.stop_trial(trial) + elif decision == TrialScheduler.NOOP: + pass + else: + raise ValueError("Invalid decision: {}".format(decision)) + + def _maybe_execute_queued_decision(self, trial: Trial, after_save: bool = False): + # `self._queued_trial_decisions` now contains a final decision + # based on all results + final_decision = self._queued_trial_decisions.pop(trial.trial_id, None) + if final_decision: + logger.debug( + f"Executing final queued decision for {trial}: {final_decision}" + ) + self._execute_action(trial, final_decision, after_save=after_save) + + def _stop_experiment_if_needed(self): + """Stops all trials.""" + fail_fast = self._fail_fast and self._has_errored + if self._stopper.stop_all() or fail_fast or self._should_stop_experiment: + self._search_alg.set_finished() + [ + self._schedule_trial_stop(t) + for t in self._trials + if t.status not in {Trial.ERROR, Trial.TERMINATED} + ] + + ### + # Failure + def _trial_task_failure(self, trial: Trial, exception: Exception): + if self._fail_fast == self.RAISE: + raise exception + else: + if self._print_trial_errors: + logger.error(f"Trial task failed for trial {trial}", exc_info=exception) + self._process_trial_failure(trial, exception=exception) + + def _process_trial_failure( + self, + trial: Trial, + exception: Union[TuneError, RayTaskError, RayActorError], + ): + """Handle trial failure. + + Attempt trial recovery if possible, clean up state otherwise. + + Args: + trial: Failed trial. + exception: Exception prior to invoking this method. + """ + self._has_errored = True + trial.handle_error(exception) + if trial.status == Trial.RUNNING and trial.should_recover(): + self._try_recover(trial, exc=exception) + self._callbacks.on_trial_recover( + iteration=self._iteration, trials=self._trials, trial=trial + ) + elif trial.status in {Trial.RUNNING, Trial.PENDING}: + self._scheduler_alg.on_trial_error(self, trial) + self._search_alg.on_trial_complete(trial.trial_id, error=True) + self._schedule_trial_stop(trial, exception=exception) + self._callbacks.on_trial_error( + iteration=self._iteration, trials=self._trials, trial=trial + ) + + def _schedule_trial_stop(self, trial: Trial, exception: Optional[Exception] = None): + if trial.status == Trial.ERROR: + logger.debug(f"Not requesting trial STOP as it is ERROR already: {trial}") + return + + logger.debug(f"Requesting to STOP actor for trial {trial}") + + if trial.is_saving: + logger.debug( + f"Trial {trial} is currently saving/pausing. Scheduling STOP after " + f"save resolved." 
+ ) + self._cached_trial_decisions[trial.trial_id] = TrialScheduler.STOP + + trial.temporary_state.saving_to = None + trial.temporary_state.restoring_from = None + + self._set_trial_status(trial, Trial.ERROR if exception else Trial.TERMINATED) + trial.set_location(_Location()) + + if trial not in self._trial_to_actor: + logger.debug(f"Will not STOP trial actor as it is not live: {trial}") + return + + tracked_actor = self._trial_to_actor[trial] + + self._actor_manager.clear_actor_task_futures(tracked_actor=tracked_actor) + + self._mark_trial_to_checkpoint(trial) + + if not exception and self._maybe_cache_trial_actor(trial): + # Trial runner has been cached + return + + logger.debug(f"Terminating actor for trial {trial}: {tracked_actor}") + + tracked_actor = self._trial_to_actor.pop(trial) + self._actor_to_trial.pop(tracked_actor) + + trial.set_ray_actor(None) + + self._remove_actor(tracked_actor=tracked_actor) + + def stop_trial(self, trial): + """The canonical implementation of stopping a trial. + + Trials may be in any external status when this function is called. + If trial is in state PENDING or PAUSED, calls `on_trial_remove` for + scheduler and `on_trial_complete()` for search_alg. + If trial is in state RUNNING, calls `on_trial_complete` for scheduler + and search_alg if RUNNING. Caller to ensure that there is no + outstanding future to be handled for the trial. If there is, the future + would be discarded. + """ + try: + if trial.status in [Trial.ERROR, Trial.TERMINATED]: + return + elif trial.status in [Trial.PENDING, Trial.PAUSED]: + self._scheduler_alg.on_trial_remove(self, trial) + self._search_alg.on_trial_complete(trial.trial_id) + elif trial.status is Trial.RUNNING: + # By this time trial.last_result should have been + # updated already. + self._scheduler_alg.on_trial_complete( + self, trial, flatten_dict(trial.last_result) + ) + self._search_alg.on_trial_complete( + trial.trial_id, result=flatten_dict(trial.last_result) + ) + self._callbacks.on_trial_complete( + iteration=self._iteration, trials=self._trials, trial=trial + ) + self._schedule_graceful_trial_stop(trial) + self._live_trials.discard(trial) + except Exception as e: + logger.exception("Trial %s: Error stopping trial.", trial) + if self._fail_fast == self.RAISE: + raise + if isinstance(e, TuneError): + self._process_trial_failure(trial, exception=e) + else: + self._process_trial_failure( + trial, _TuneStopTrialError(traceback.format_exc()) + ) + + def _schedule_graceful_trial_stop(self, trial: Trial): + self._schedule_trial_export(trial) + if trial.status != "ERROR": + self._schedule_trial_stop(trial) + + def _schedule_trial_pause(self, trial: Trial, should_checkpoint: bool = True): + if trial not in self._trial_to_actor: + logger.debug( + f"Trial PAUSE requested for trial {trial} but trial is already " + f"stopping. Ignoring." 
+ ) + return + + if should_checkpoint: + self._cached_trial_decisions[trial.trial_id] = TrialScheduler.PAUSE + self._schedule_trial_save(trial=trial) + else: + self._schedule_trial_stop(trial) + self._set_trial_status(trial, Trial.PAUSED) + + ### + # TRAIN + + def _schedule_trial_train(self, trial: Trial): + args = () + method_name = "train" + + buffer_length, buffer_time_s = self._maybe_buffer_training(trial) + + if buffer_length > 1: + method_name = "train_buffered" + args = (buffer_length, buffer_time_s) + + logger.debug(f"Scheduling future {method_name.upper()} for trial {trial}") + + self._schedule_trial_task( + trial=trial, + method_name=method_name, + args=args, + on_result=self._on_training_result, + on_error=self._trial_task_failure, + ) + + def _maybe_buffer_training(self, trial: Trial) -> Tuple[int, float]: + buffer_time_s = max( + self._buffer_min_time_s, + min(self._buffer_max_time_s, self._actor_manager.num_actor_tasks // 10), + ) + buffer_length = self._buffer_length + + if buffer_length > 1 and trial.checkpoint_at_end: + # If a trial checkpoint can be triggered externally, + # it is not safe to buffer results. + if log_once("trial_executor_buffer_checkpoint"): + logger.warning( + "Disabling buffered training as you passed " + "`checkpoint_at_end` to `train.CheckpointConfig()`." + ) + return 1, buffer_time_s + + if buffer_length > 1 and trial.checkpoint_freq > 0: + return min(buffer_length, trial.checkpoint_freq), buffer_time_s + + return buffer_length, buffer_time_s + + ### + # RESULT + + def _on_training_result(self, trial, result): + if not isinstance(result, list): + result = [result] + with warn_if_slow("process_trial_result"): + self._process_trial_results(trial, result) + self._maybe_execute_queued_decision(trial, after_save=False) + + def _process_trial_results(self, trial, results): + logger.debug(f"Processing trial results for trial {trial}: {results}") + with warn_if_slow( + "process_trial_results", + message="Processing trial results took {duration:.3f} s, " + "which may be a performance bottleneck. Please consider " + "reporting results less frequently to Ray Tune.", + ): + for i, result in enumerate(results): + with warn_if_slow("process_trial_result"): + decision = self._process_trial_result(trial, result) + if decision is None: + # If we didn't get a decision, this means a + # non-training future (e.g. a save) was scheduled. + # We do not allow processing more results then. + if i < len(results) - 1: + if log_once("tune_controller_buffer_checkpoint"): + logger.warning( + f"Trial {trial} has a non-training future " + f"scheduled but {len(results) - i} results " + f"left to process. This means that a " + f"checkpoint was requested, but buffered " + f"training was continued before it was " + f"saved. Consider using non-buffered " + f"training by setting the env variable " + f"`TUNE_RESULT_BUFFER_LENGTH=1`." + ) + elif decision == TrialScheduler.STOP: + # If the decision is to stop the trial, + # ignore all results that came after that. + break + + def _process_trial_result(self, trial, result): + result.update(trial_id=trial.trial_id) + is_duplicate = RESULT_DUPLICATE in result + force_checkpoint = result.get(SHOULD_CHECKPOINT, False) + # TrialScheduler and SearchAlgorithm still receive a + # notification because there may be special handling for + # the `on_trial_complete` hook. 
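+ # After the duplicate check below, the result metrics are validated; the + # Stopper and trial.should_stop() can force a STOP decision, otherwise the + # scheduler decides and, for non-STOP decisions, the searcher is updated. + # Non-duplicate results are also passed to the callbacks and recorded as the + # trial's last result. A checkpoint may then be written (PAUSE schedules its + # own save instead), and the decision is either cached while a save is in + # flight or queued for execution.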
+ if is_duplicate: + logger.debug("Trial finished without logging 'done'.") + result = trial.last_result + result.update(done=True) + + self._total_time += result.get(TIME_THIS_ITER_S, 0) + + flat_result = flatten_dict(result) + self._validate_result_metrics(flat_result) + + if self._stopper(trial.trial_id, result) or trial.should_stop(flat_result): + decision = TrialScheduler.STOP + else: + with warn_if_slow("scheduler.on_trial_result"): + decision = self._scheduler_alg.on_trial_result( + self._wrapped(), trial, flat_result + ) + if decision == TrialScheduler.STOP: + result.update(done=True) + else: + # Only updating search alg if the trial is not to be stopped. + with warn_if_slow("search_alg.on_trial_result"): + self._search_alg.on_trial_result(trial.trial_id, flat_result) + + # If this is not a duplicate result, the callbacks should + # be informed about the result. + if not is_duplicate: + with warn_if_slow("callbacks.on_trial_result"): + self._callbacks.on_trial_result( + iteration=self._iteration, + trials=self._trials, + trial=trial, + result=result.copy(), + ) + trial.update_last_result(result) + # Include in next experiment checkpoint + self._mark_trial_to_checkpoint(trial) + + # Checkpoints to disk. This should be checked even if + # the scheduler decision is STOP or PAUSE. Note that + # PAUSE only checkpoints to memory and does not update + # the global checkpoint state. + if decision != TrialScheduler.PAUSE: + # TODO(justinvyu): This is a temporary hack to fix pausing trials. + # We already schedule a save task in `pause_trial`, so no need + # to do it again here. + self._checkpoint_trial_if_needed(trial, force=force_checkpoint) + + if trial.is_saving: + logger.debug(f"Caching trial decision for trial {trial}: {decision}") + # Cache decision to execute on after the save is processed. + # This prevents changing the trial's state or kicking off + # another training step prematurely. + if not self._cached_trial_decisions.get(trial.trial_id) or decision in { + TrialScheduler.PAUSE, + TrialScheduler.STOP, + }: + # If already set, only overwrite if it's a PAUSE or STOP. This is + # to avoid that CONTINUE decisions from a training step that resolve + # late overwrite PAUSE/STOP decision. + self._cached_trial_decisions[trial.trial_id] = decision + return None + else: + self._queue_decision(trial, decision) + return decision + + def _validate_result_metrics(self, result): + """ + Check if any of the required metrics was not reported + in the last result. If the only items are ``done`` or any of + DEBUG_METRICS, this means that no result was ever received and + the trial just returned. This is also okay and will not raise + an error. + + This will ignore checking for the DEFAULT_METRIC. 
+ """ + if int(os.environ.get("TUNE_DISABLE_STRICT_METRIC_CHECKING", 0)) != 1 and ( + len({k for k in result if k not in list(DEBUG_METRICS) + [DONE]}) > 1 + ): + base_metric = self._metric if self._metric != DEFAULT_METRIC else None + scheduler_metric = ( + self._scheduler_alg.metric + if self._scheduler_alg.metric != DEFAULT_METRIC + else None + ) + search_metrics = ( + self._search_alg.metric + if self._search_alg.metric != DEFAULT_METRIC + else None + ) + + if isinstance(search_metrics, str): + search_metrics = [search_metrics] + + if base_metric and base_metric not in result: + report_metric = base_metric + location = "tune.TuneConfig()" + elif scheduler_metric and scheduler_metric not in result: + report_metric = scheduler_metric + location = type(self._scheduler_alg).__name__ + elif search_metrics and any( + search_metric not in result for search_metric in search_metrics + ): + report_metric = list( + filter( + lambda search_metric: search_metric not in result, + search_metrics, + ) + ) + if len(report_metric) == 1: + report_metric = report_metric[0] + location = type(self._search_alg).__name__ + else: + report_metric = None + location = None + + if report_metric: + raise ValueError( + "Trial returned a result which did not include the " + "specified metric(s) `{}` that `{}` expects. " + "Make sure your calls to `tune.report()` include the " + "metric, or set the " + "TUNE_DISABLE_STRICT_METRIC_CHECKING " + "environment variable to 1. Result: {}".format( + report_metric, location, result + ) + ) + + ### + # SAVE + def _schedule_trial_save( + self, + trial: Trial, + result: Optional[Dict] = None, + ) -> Optional[_FutureTrainingResult]: + if trial not in self._trial_to_actor: + logger.debug( + f"Trial SAVE requested for trial {trial} but trial is already " + f"stopping. Ignoring." + ) + return None + + result = result or trial.last_result + + future = self._schedule_trial_task( + trial=trial, + method_name="save", + on_result=self._on_saving_result, + on_error=self._trial_task_failure, + _return_future=True, + ) + # TODO(justinvyu): `trial.saving_to` (and trial.is_saving) is needed + # in order to prevent a done=True result from executing a STOP decision + # (which clears all futures) before the save gets processed. + # Keep this in for now while `train` and `save` are 2 separate steps. + trial.temporary_state.saving_to = _FutureTrainingResult(future) + + # `trial.saving_to` holds a future training result -- this is only used + # in the case of PBT to block until the checkpoint is ready. + # In all other situations, the checkpoint future is processed by the + # actor event manager when it is ready. + return trial.temporary_state.saving_to + + def _on_saving_result(self, trial, checkpoint_value: _TrainingResult): + with warn_if_slow("process_trial_save"): + self._process_trial_save(trial, checkpoint_value) + + with warn_if_slow("callbacks.on_trial_save"): + self._callbacks.on_trial_save( + iteration=self._iteration, trials=self._trials, trial=trial + ) + + self._maybe_execute_queued_decision(trial, after_save=True) + + def _process_trial_save(self, trial: Trial, checkpoint_value: _TrainingResult): + """Processes a trial save. + + Acts on the decision cached during the last `_process_trial` call. + + Args: + trial: Trial being saved. 
+ """ + logger.debug("Trial %s: Processing trial save.", trial) + + try: + if not checkpoint_value.checkpoint: + logger.debug(f"Got empty checkpoint for trial {trial}") + else: + try: + self._callbacks.on_checkpoint( + iteration=self._iteration, + trials=self._trials, + trial=trial, + checkpoint=checkpoint_value.checkpoint, + ) + except Exception: + logger.warning( + "Error encountered during processing of callbacks. " + "Ray Train/Tune recently changed the checkpoint interface " + "that is passed to callbacks. If you implemented your own " + "callback with an `on_checkpoint` handler, please review " + "the checkpoint interface and adjust your code " + "accordingly." + ) + raise + + trial.on_checkpoint(checkpoint_value) + + self._checkpoint_manager.on_trial_checkpoint(trial) + + self._mark_trial_to_checkpoint(trial) + except Exception: + logger.exception( + "Trial %s: Error handling checkpoint %s", trial, checkpoint_value + ) + + trial.temporary_state.saving_to = None + decision = self._cached_trial_decisions.pop(trial.trial_id, None) + if decision and checkpoint_value: + self._queue_decision(trial, decision) + + def _checkpoint_trial_if_needed(self, trial, force=False): + """Checkpoints trial based off trial.last_result.""" + if trial.should_checkpoint() or force: + # Save trial runtime if possible. + if trial.temporary_state.ray_actor: + self._schedule_trial_save(trial) + + ### + # RESTORE + def _schedule_trial_restore(self, trial: Trial) -> bool: + checkpoint_result = trial.latest_checkpoint_result + + if not checkpoint_result: + logger.debug(f"Not restoring trial {trial}: No checkpoint found.") + return False + + # TODO(justinvyu): Is this really needed? + trial.temporary_state.restoring_from = checkpoint_result + + method_name = "restore" + args = (checkpoint_result,) + self._schedule_trial_task( + trial=trial, + method_name=method_name, + args=args, + kwargs={}, + on_result=self._on_restoring_result, + on_error=self._trial_task_failure, + ) + return True + + def _on_restoring_result(self, trial: Trial, result: Any): + self._process_trial_restore(trial) + + def _process_trial_restore(self, trial: Trial): + """Processes a trial restore. + + Args: + trial: Trial being restored. + """ + logger.debug("Trial %s: Processing trial restore.", trial) + trial.on_restore() + logger.debug("Trial %s: Restore processed successfully", trial) + self._set_trial_status(trial, Trial.RUNNING) + self._schedule_trial_train(trial) + self._live_trials.add(trial) + + def _try_recover( + self, trial: Trial, exc: Union[TuneError, RayTaskError, RayActorError] + ): + """Tries to recover trial. + + Notifies SearchAlgorithm and Scheduler if failure to recover. + + Args: + trial: Trial to recover. + exc: Exception prior to invoking this method. + """ + self._cached_trial_decisions.pop(trial.trial_id, None) + # Resetting this, in case that the trial is in saving status when it crashes. + if trial.is_saving: + trial.temporary_state.saving_to = None + self._schedule_trial_stop(trial, exception=exc) + + logger.debug("Trial %s: Notifying Scheduler and requeueing.", trial) + self._requeue_trial(trial) + + def _requeue_trial(self, trial): + """Notification to TrialScheduler and requeue trial. + + This does not notify the SearchAlgorithm because the function + evaluation is still in progress. + + """ + self._scheduler_alg.on_trial_error(self, trial) + self._set_trial_status(trial, status=Trial.PENDING) + + # TODO(rliaw): Right now, this pushes the trial to the end of queue + # because restoration can be expensive. 
However, this is not + # ideal since it just hides the issue - a better fix would + # be to use an actor table to detect the IP of the Trainable + # and rsync the files there. + # See https://github.com/ray-project/ray/issues/5168 + self._trials.pop(self._trials.index(trial)) + self._trials.append(trial) + self._live_trials.add(trial) + + with warn_if_slow("scheduler.on_trial_add"): + self._scheduler_alg.on_trial_add(self._wrapped(), trial) + + ### + # EXPORT + def _schedule_trial_export(self, trial: Trial): + if not trial.export_formats or len(trial.export_formats) <= 0: + return + + # Todo: We are waiting here synchronously until the task resolved. + # Instead, we should schedule the trial stop after the export resolved. + # This requires changes in TrialRunner, which we can remove once the + # legacy execution path has been removed. + future = self._schedule_trial_task( + trial=trial, + method_name="export_model", + args=(trial.export_formats,), + on_result=None, + on_error=self._trial_task_failure, + _return_future=True, + ) + self._actor_manager._actor_task_events.resolve_future(future) + + ### + # RESET + def _schedule_trial_reset( + self, + trial: Trial, + new_config: Dict, + new_experiment_tag: str, + ): + trial.set_experiment_tag(new_experiment_tag) + trial.set_config(new_config) + + # Pass magic variables + extra_config = copy.deepcopy(new_config) + extra_config[TRIAL_INFO] = _TrialInfo(trial) + + stdout_file, stderr_file = trial.log_to_file + extra_config[STDOUT_FILE] = stdout_file + extra_config[STDERR_FILE] = stderr_file + + logger_creator = partial( + _noop_logger_creator, logdir=trial.storage.trial_working_directory + ) + + self._resetting_trials.add(trial) + self._schedule_trial_task( + trial=trial, + method_name="reset", + args=(extra_config,), + kwargs={ + "logger_creator": logger_creator, + "storage": trial.storage, + }, + on_result=self._on_trial_reset, + on_error=self._trial_task_failure, + ) + + def _on_trial_reset(self, trial: Trial, success: bool): + self._resetting_trials.remove(trial) + + if not success: + info = ( + "Trainable runner reuse requires reset_config() to be " + "implemented and return True." + ) + + logger.error(f"Could not re-use actor for trial {trial}: {info}") + + exception = _AbortTrialExecution(info) + + trial.handle_error(exception) + self._schedule_trial_stop(trial, exception=exception) + return + + tracked_actor = self._trial_to_actor[trial] + + self._actor_started(tracked_actor, log="REUSED") + + def request_stop_trial(self, trial): + self._stop_queue.append(trial) + + def request_stop_experiment(self): + self._should_stop_experiment = True + + def _process_stop_requests(self): + while self._stop_queue: + t = self._stop_queue.pop() + self.stop_trial(t) + + def pause_trial(self, trial: Trial, should_checkpoint: bool = True): + """Pause a trial and reset the necessary state variables for resuming later. + + Args: + trial: Trial to pause. + should_checkpoint: Whether or not an in-memory checkpoint should be created + for this paused trial. Defaults to True. + """ + # NOTE: The cached trial decision is not needed since we will overrule this + # decision with PAUSE. + self._cached_trial_decisions.pop(trial.trial_id, None) + self._schedule_trial_pause(trial, should_checkpoint=should_checkpoint) + + def cleanup(self): + """Cleanup trials and callbacks.""" + self._cleanup_trials() + self.end_experiment_callbacks() + + def __getstate__(self): + """Gets state for trial. 
+ + Note that this is not used as a pickling override as + does not have all fields. + """ + state = self.__dict__.copy() + for k in [ + "_trials", + "_live_trials", + "_stop_queue", + "_search_alg", + "_placeholder_resolvers", + "_scheduler_alg", + "_pending_trial_queue_times", + "_callbacks", + "_checkpoint_manager", + "_storage", + "_insufficient_resources_manager", + "_actor_manager", + "_class_cache", + "_resource_updater", + "_trials_to_cache", + "_trial_metadata", + "_actor_to_trial", + "_trial_to_actor", + "_resources_to_pending_trials", + "_pending_trials", + "_pending_trials_list", + "_running_trials", + "_paused_trials", + "_stopped_trials", + "_failed_trials", + "_resetting_trials", + "_started_actors", + "_stopping_actors", + "_staged_trials", + "_actor_cache", + ]: + del state[k] + return state + + def __setstate__(self, state): + # Use session_str from previous checkpoint if does not exist + session_str = state.pop("_session_str") + self.__dict__.setdefault("_session_str", session_str) + # Use start_time from previous checkpoint if does not exist + start_time = state.pop("_start_time") + self.__dict__.setdefault("_start_time", start_time) + + self.__dict__.update(state) + self._checkpoint_manager = self._create_checkpoint_manager() + + +class _TrialExecutorWrapper: + """Wraps around TrialExecutor class, intercepts API calls and warns users + of restricted API access. + + This is meant to facilitate restricting + the current API exposure of TrialExecutor by TrialScheduler. + """ + + def __init__( + self, + trial_executor: "_FakeRayTrialExecutor", + whitelist_attr: Optional[set] = None, + ): + self._trial_executor = trial_executor + self._whitelist_attr = whitelist_attr or set() + + for attr in self._whitelist_attr: + assert hasattr(self._trial_executor, attr) + + def __getattr__(self, attr): + if attr not in self._whitelist_attr: + if log_once("restrict_accessing_trial_executor"): + logger.warning( + f"You are trying to access {attr} interface of " + f"TrialExecutor in TrialScheduler, which is being " + f"restricted. If you believe it is reasonable for " + f"your scheduler to access this TrialExecutor API, " + f"please reach out to Ray team on GitHub. A more " + f"strict API access pattern would be enforced " + f"starting 1.12.0" + ) + return getattr(self._trial_executor, attr) + + +@DeveloperAPI +class TrialRunnerWrapper: + """Wraps around TrialRunner class, intercepts API calls and warns users + of restricted API access. + + This is meant to facilitate restricting + the current API exposure of TrialRunner by TrialScheduler. + """ + + _EXECUTOR_ATTR = "trial_executor" + + def __init__( + self, + tune_controller: TuneController, + trial_executor: Any, + runner_whitelist_attr: Optional[set] = None, + executor_whitelist_attr: Optional[set] = None, + ): + self._tune_controller = tune_controller + self._trial_executor = _TrialExecutorWrapper( + trial_executor, executor_whitelist_attr + ) + self._runner_whitelist_attr = runner_whitelist_attr or set() + + for attr in self._runner_whitelist_attr: + assert hasattr(self, attr) + + def __getattr__(self, attr): + if attr == self._EXECUTOR_ATTR: + return self._trial_executor + if attr not in self._runner_whitelist_attr: + if log_once("restrict_accessing_tune_controller"): + logger.warning( + f"You are trying to access {attr} interface of " + f"TrialRunner in TrialScheduler, which is being " + f"restricted. 
If you believe it is reasonable for " + f"your scheduler to access this TrialRunner API, " + f"please reach out to Ray team on GitHub. A more " + f"strict API access pattern would be enforced " + f"starting 1.12.0" + ) + return getattr(self._tune_controller, attr) + + +def _get_max_pending_trials(search_alg: SearchAlgorithm) -> int: + max_pending_trials = os.getenv("TUNE_MAX_PENDING_TRIALS_PG", "auto") + + if max_pending_trials != "auto": + return int(max_pending_trials) + + # Else, auto detect. + + # Only BasicVariantGenerator supports > 1 pending trials. + # This is because we don't want to generate too many trials + # before we fit the searcher model. + if not isinstance(search_alg, BasicVariantGenerator): + return 1 + + # Allow at least 200 pending trials to trigger fast autoscaling + min_autoscaling_rate = 200 + + # Allow more pending trials for larger clusters (based on number of CPUs) + cluster_cpus = ray.cluster_resources().get("CPU", 1.0) + max_pending_trials = max(min_autoscaling_rate, int(cluster_cpus * 1.1)) + + if max_pending_trials > min_autoscaling_rate: + logger.warning( + f"The maximum number of pending trials has been " + f"automatically set to the number of available " + f"cluster CPUs, which is high " + f"({max_pending_trials} CPUs/pending trials). " + f"If you're running an experiment with a large number " + f"of trials, this could lead to scheduling overhead. " + f"In this case, consider setting the " + f"`TUNE_MAX_PENDING_TRIALS_PG` environment variable " + f"to the desired maximum number of concurrent pending trials." + ) + + return max_pending_trials + + +class _FakeRayTrialExecutor: + """The TuneController does not use a RayTrialExecutor anymore. + + Instead, we pass this fake executor for searchers/schedulers to use + as an interface. + + In the future, we should have the searchers/schedulers either interact with + the tune controller, or define a different API for more fine-grained scheduler + control.
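+ + Only the executor surface used by those components is implemented: + ``pause_trial``, ``save``, a ``has_resources_for_trial`` check that always + returns True, access to the controller's resource updater, and a no-op + reconciliation hook.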
+ """ + + def __init__(self, tune_controller: TuneController): + self._tune_controller = tune_controller + + def pause_trial(self, trial: Trial, should_checkpoint: bool = True): + return self._tune_controller._schedule_trial_pause( + trial, should_checkpoint=should_checkpoint + ) + + def save( + self, + trial: Trial, + result: Optional[Dict] = None, + ) -> Optional[_FutureTrainingResult]: + return self._tune_controller._schedule_trial_save(trial=trial, result=result) + + def has_resources_for_trial(self, trial: Trial): + return True + + @property + def _resource_updater(self): + return self._tune_controller._resource_updater + + def force_reconcilation_on_next_step_end(self): + pass diff --git a/.venv/lib/python3.11/site-packages/ray/tune/experiment/__init__.py b/.venv/lib/python3.11/site-packages/ray/tune/experiment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..39a20fc56e8e49826cbc5c4cd705efe7eb4b3f4f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/experiment/__init__.py @@ -0,0 +1,4 @@ +from ray.tune.experiment.experiment import Experiment, _convert_to_experiment_list +from ray.tune.experiment.trial import Trial + +__all__ = ["Experiment", "_convert_to_experiment_list", "Trial"] diff --git a/.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e90c021880ad7fd85dcc82cb12599c73ac2043f0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/config_parser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/config_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb33bf20cb32258a8f03a5cf7b5025f9079edd70 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/config_parser.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/experiment.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/experiment.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f1777b20a5eefc264889282b5ac58a633467f07 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/experiment.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/trial.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/trial.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dcc89f8dc40bf27715de46ad345dab5fabed571 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/trial.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/experiment/config_parser.py b/.venv/lib/python3.11/site-packages/ray/tune/experiment/config_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..859f11402b9e4833f3656b0a8d9eca6e914059d6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/experiment/config_parser.py @@ -0,0 +1,210 @@ +import argparse +import json + +from ray.train import CheckpointConfig +from ray.tune.error import TuneError +from ray.tune.experiment import Trial +from 
ray.tune.resources import json_to_resources + +# For compatibility under py2 to consider unicode as str +from ray.tune.utils.serialization import TuneFunctionEncoder +from ray.tune.utils.util import SafeFallbackEncoder + + +def _make_parser(parser_creator=None, **kwargs): + """Returns a base argument parser for the ray.tune tool. + + Args: + parser_creator: A constructor for the parser class. + kwargs: Non-positional args to be passed into the + parser class constructor. + """ + + if parser_creator: + parser = parser_creator(**kwargs) + else: + parser = argparse.ArgumentParser(**kwargs) + + # Note: keep this in sync with rllib/train.py + parser.add_argument( + "--run", + default=None, + type=str, + help="The algorithm or model to train. This may refer to the name " + "of a built-on algorithm (e.g. RLlib's DQN or PPO), or a " + "user-defined trainable function or class registered in the " + "tune registry.", + ) + parser.add_argument( + "--stop", + default="{}", + type=json.loads, + help="The stopping criteria, specified in JSON. The keys may be any " + "field returned by 'train()' e.g. " + '\'{"time_total_s": 600, "training_iteration": 100000}\' to stop ' + "after 600 seconds or 100k iterations, whichever is reached first.", + ) + parser.add_argument( + "--config", + default="{}", + type=json.loads, + help="Algorithm-specific configuration (e.g. env, hyperparams), " + "specified in JSON.", + ) + parser.add_argument( + "--resources-per-trial", + default=None, + type=json_to_resources, + help="Override the machine resources to allocate per trial, e.g. " + '\'{"cpu": 64, "gpu": 8}\'. Note that GPUs will not be assigned ' + "unless you specify them here. For RLlib, you probably want to " + "leave this alone and use RLlib configs to control parallelism.", + ) + parser.add_argument( + "--num-samples", + default=1, + type=int, + help="Number of times to repeat each trial.", + ) + parser.add_argument( + "--checkpoint-freq", + default=0, + type=int, + help="How many training iterations between checkpoints. " + "A value of 0 (default) disables checkpointing.", + ) + parser.add_argument( + "--checkpoint-at-end", + action="store_true", + help="Whether to checkpoint at the end of the experiment. Default is False.", + ) + parser.add_argument( + "--keep-checkpoints-num", + default=None, + type=int, + help="Number of best checkpoints to keep. Others get " + "deleted. Default (None) keeps all checkpoints.", + ) + parser.add_argument( + "--checkpoint-score-attr", + default="training_iteration", + type=str, + help="Specifies by which attribute to rank the best checkpoint. " + "Default is increasing order. If attribute starts with min- it " + "will rank attribute in decreasing order. Example: " + "min-validation_loss", + ) + parser.add_argument( + "--export-formats", + default=None, + help="List of formats that exported at the end of the experiment. " + "Default is None. For RLlib, 'checkpoint' and 'model' are " + "supported for TensorFlow policy graphs.", + ) + parser.add_argument( + "--max-failures", + default=3, + type=int, + help="Try to recover a trial from its last checkpoint at least this " + "many times. 
Only applies if checkpointing is enabled.", + ) + parser.add_argument( + "--scheduler", + default="FIFO", + type=str, + help="FIFO (default), MedianStopping, AsyncHyperBand, " + "HyperBand, or HyperOpt.", + ) + parser.add_argument( + "--scheduler-config", + default="{}", + type=json.loads, + help="Config options to pass to the scheduler.", + ) + + # Note: this currently only makes sense when running a single trial + parser.add_argument( + "--restore", + default=None, + type=str, + help="If specified, restore from this checkpoint.", + ) + + return parser + + +def _to_argv(config): + """Converts configuration to a command line argument format.""" + argv = [] + for k, v in config.items(): + if "-" in k: + raise ValueError("Use '_' instead of '-' in `{}`".format(k)) + if v is None: + continue + if not isinstance(v, bool) or v: # for argparse flags + argv.append("--{}".format(k.replace("_", "-"))) + if isinstance(v, str): + argv.append(v) + elif isinstance(v, bool): + pass + elif callable(v): + argv.append(json.dumps(v, cls=TuneFunctionEncoder)) + else: + argv.append(json.dumps(v, cls=SafeFallbackEncoder)) + return argv + + +_cached_pgf = {} + + +def _create_trial_from_spec( + spec: dict, parser: argparse.ArgumentParser, **trial_kwargs +): + """Creates a Trial object from parsing the spec. + + Args: + spec: A resolved experiment specification. The args here should + correspond to the command line flags in + ray.tune.experiment.config_parser. + parser: An argument parser object from + make_parser. + trial_kwargs: Extra keyword arguments used in instantiating the Trial. + + Returns: + A trial object with corresponding parameters to the specification. + """ + global _cached_pgf + + spec = spec.copy() + resources = spec.pop("resources_per_trial", None) + + try: + args, _ = parser.parse_known_args(_to_argv(spec)) + except SystemExit: + raise TuneError("Error parsing args, see above message", spec) + + if resources: + trial_kwargs["placement_group_factory"] = resources + + checkpoint_config = spec.get("checkpoint_config", CheckpointConfig()) + + return Trial( + # Submitting trial via server in py2.7 creates Unicode, which does not + # convert to string in a straightforward manner.
+ trainable_name=spec["run"], + # json.load leads to str -> unicode in py2.7 + config=spec.get("config", {}), + # json.load leads to str -> unicode in py2.7 + stopping_criterion=spec.get("stop", {}), + checkpoint_config=checkpoint_config, + export_formats=spec.get("export_formats", []), + # str(None) doesn't create None + restore_path=spec.get("restore"), + trial_name_creator=spec.get("trial_name_creator"), + trial_dirname_creator=spec.get("trial_dirname_creator"), + log_to_file=spec.get("log_to_file"), + # str(None) doesn't create None + max_failures=args.max_failures, + storage=spec.get("storage"), + **trial_kwargs, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/experiment/experiment.py b/.venv/lib/python3.11/site-packages/ray/tune/experiment/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..30a8a2fd6fc2550f2fba1b6eb67c984fd972da2c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/experiment/experiment.py @@ -0,0 +1,445 @@ +import copy +import datetime +import logging +import pprint as pp +import traceback +from functools import partial +from pathlib import Path +from pickle import PicklingError +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Sequence, + Type, + Union, +) + +import ray +from ray.exceptions import RpcError +from ray.train import CheckpointConfig, SyncConfig +from ray.train._internal.storage import StorageContext +from ray.train.constants import DEFAULT_STORAGE_PATH +from ray.tune.error import TuneError +from ray.tune.registry import is_function_trainable, register_trainable +from ray.tune.stopper import CombinedStopper, FunctionStopper, Stopper, TimeoutStopper +from ray.util.annotations import Deprecated, DeveloperAPI + +if TYPE_CHECKING: + import pyarrow.fs + + from ray.tune import PlacementGroupFactory + from ray.tune.experiment import Trial + + +logger = logging.getLogger(__name__) + + +def _validate_log_to_file(log_to_file): + """Validate ``train.RunConfig``'s ``log_to_file`` parameter. Return + validated relative stdout and stderr filenames.""" + if not log_to_file: + stdout_file = stderr_file = None + elif isinstance(log_to_file, bool) and log_to_file: + stdout_file = "stdout" + stderr_file = "stderr" + elif isinstance(log_to_file, str): + stdout_file = stderr_file = log_to_file + elif isinstance(log_to_file, Sequence): + if len(log_to_file) != 2: + raise ValueError( + "If you pass a Sequence to `log_to_file` it has to have " + "a length of 2 (for stdout and stderr, respectively). The " + "Sequence you passed has length {}.".format(len(log_to_file)) + ) + stdout_file, stderr_file = log_to_file + else: + raise ValueError( + "You can pass a boolean, a string, or a Sequence of length 2 to " + "`log_to_file`, but you passed something else ({}).".format( + type(log_to_file) + ) + ) + return stdout_file, stderr_file + + +@DeveloperAPI +class Experiment: + """Tracks experiment specifications. + + Implicitly registers the Trainable if needed. The args here take + the same meaning as the arguments defined `tune.py:run`. + + .. code-block:: python + + experiment_spec = Experiment( + "my_experiment_name", + my_func, + stop={"mean_accuracy": 100}, + config={ + "alpha": tune.grid_search([0.2, 0.4, 0.6]), + "beta": tune.grid_search([1, 2]), + }, + resources_per_trial={ + "cpu": 1, + "gpu": 0 + }, + num_samples=10, + local_dir="~/ray_results", + checkpoint_freq=10, + max_failures=2) + + """ + + # Keys that will be present in `public_spec` dict. 
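+ # For the example in the class docstring above, `public_spec` would contain, + # e.g., {"stop": {"mean_accuracy": 100}, "num_samples": 10, "time_budget_s": None}.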
+ PUBLIC_KEYS = {"stop", "num_samples", "time_budget_s"} + _storage_context_cls = StorageContext + + def __init__( + self, + name: str, + run: Union[str, Callable, Type], + *, + stop: Optional[Union[Mapping, Stopper, Callable[[str, Mapping], bool]]] = None, + time_budget_s: Optional[Union[int, float, datetime.timedelta]] = None, + config: Optional[Dict[str, Any]] = None, + resources_per_trial: Union[ + None, Mapping[str, Union[float, int, Mapping]], "PlacementGroupFactory" + ] = None, + num_samples: int = 1, + storage_path: Optional[str] = None, + storage_filesystem: Optional["pyarrow.fs.FileSystem"] = None, + sync_config: Optional[Union[SyncConfig, dict]] = None, + checkpoint_config: Optional[Union[CheckpointConfig, dict]] = None, + trial_name_creator: Optional[Callable[["Trial"], str]] = None, + trial_dirname_creator: Optional[Callable[["Trial"], str]] = None, + log_to_file: bool = False, + export_formats: Optional[Sequence] = None, + max_failures: int = 0, + restore: Optional[str] = None, + # Deprecated + local_dir: Optional[str] = None, + ): + if isinstance(checkpoint_config, dict): + checkpoint_config = CheckpointConfig(**checkpoint_config) + else: + checkpoint_config = checkpoint_config or CheckpointConfig() + + if is_function_trainable(run): + if checkpoint_config.checkpoint_at_end: + raise ValueError( + "'checkpoint_at_end' cannot be used with a function trainable. " + "You should include one last call to " + "`ray.train.report(metrics=..., checkpoint=...)` " + "at the end of your training loop to get this behavior." + ) + if checkpoint_config.checkpoint_frequency: + raise ValueError( + "'checkpoint_frequency' cannot be set for a function trainable. " + "You will need to report a checkpoint every " + "`checkpoint_frequency` iterations within your training loop using " + "`ray.train.report(metrics=..., checkpoint=...)` " + "to get this behavior." + ) + try: + self._run_identifier = Experiment.register_if_needed(run) + except RpcError as e: + if e.rpc_code == ray._raylet.GRPC_STATUS_CODE_RESOURCE_EXHAUSTED: + raise TuneError( + f"The Trainable/training function is too large for grpc resource " + f"limit. Check that its definition is not implicitly capturing a " + f"large array or other object in scope. " + f"Tip: use tune.with_parameters() to put large objects " + f"in the Ray object store. \n" + f"Original exception: {traceback.format_exc()}" + ) + else: + raise e + + if not name: + name = StorageContext.get_experiment_dir_name(run) + + storage_path = storage_path or DEFAULT_STORAGE_PATH + self.storage = self._storage_context_cls( + storage_path=storage_path, + storage_filesystem=storage_filesystem, + sync_config=sync_config, + experiment_dir_name=name, + ) + logger.debug(f"StorageContext on the DRIVER:\n{self.storage}") + + config = config or {} + if not isinstance(config, dict): + raise ValueError( + f"`Experiment(config)` must be a dict, got: {type(config)}. " + "Please convert your search space to a dict before passing it in." + ) + + self._stopper = None + stopping_criteria = {} + if not stop: + pass + elif isinstance(stop, list): + bad_stoppers = [s for s in stop if not isinstance(s, Stopper)] + if bad_stoppers: + stopper_types = [type(s) for s in stop] + raise ValueError( + "If you pass a list as the `stop` argument to " + "`train.RunConfig()`, each element must be an instance of " + f"`tune.stopper.Stopper`. Got {stopper_types}." 
+ ) + self._stopper = CombinedStopper(*stop) + elif isinstance(stop, dict): + stopping_criteria = stop + elif callable(stop): + if FunctionStopper.is_valid_function(stop): + self._stopper = FunctionStopper(stop) + elif isinstance(stop, Stopper): + self._stopper = stop + else: + raise ValueError( + "Provided stop object must be either a dict, " + "a function, or a subclass of " + f"`ray.tune.Stopper`. Got {type(stop)}." + ) + else: + raise ValueError( + f"Invalid stop criteria: {stop}. Must be a " + f"callable or dict. Got {type(stop)}." + ) + + if time_budget_s: + if self._stopper: + self._stopper = CombinedStopper( + self._stopper, TimeoutStopper(time_budget_s) + ) + else: + self._stopper = TimeoutStopper(time_budget_s) + + stdout_file, stderr_file = _validate_log_to_file(log_to_file) + + spec = { + "run": self._run_identifier, + "stop": stopping_criteria, + "time_budget_s": time_budget_s, + "config": config, + "resources_per_trial": resources_per_trial, + "num_samples": num_samples, + "checkpoint_config": checkpoint_config, + "trial_name_creator": trial_name_creator, + "trial_dirname_creator": trial_dirname_creator, + "log_to_file": (stdout_file, stderr_file), + "export_formats": export_formats or [], + "max_failures": max_failures, + "restore": ( + Path(restore).expanduser().absolute().as_posix() if restore else None + ), + "storage": self.storage, + } + self.spec = spec + + @classmethod + def from_json(cls, name: str, spec: dict): + """Generates an Experiment object from JSON. + + Args: + name: Name of Experiment. + spec: JSON configuration of experiment. + """ + if "run" not in spec: + raise TuneError("No trainable specified!") + + # Special case the `env` param for RLlib by automatically + # moving it into the `config` section. + if "env" in spec: + spec["config"] = spec.get("config", {}) + spec["config"]["env"] = spec["env"] + del spec["env"] + + if "sync_config" in spec and isinstance(spec["sync_config"], dict): + spec["sync_config"] = SyncConfig(**spec["sync_config"]) + + if "checkpoint_config" in spec and isinstance(spec["checkpoint_config"], dict): + spec["checkpoint_config"] = CheckpointConfig(**spec["checkpoint_config"]) + + spec = copy.deepcopy(spec) + + run_value = spec.pop("run") + try: + exp = cls(name, run_value, **spec) + except TypeError as e: + raise TuneError( + f"Failed to load the following Tune experiment " + f"specification:\n\n {pp.pformat(spec)}.\n\n" + f"Please check that the arguments are valid. " + f"Experiment creation failed with the following " + f"error:\n {e}" + ) + return exp + + @classmethod + def get_trainable_name(cls, run_object: Union[str, Callable, Type]): + """Get Trainable name. + + Args: + run_object: Trainable to run. If string, + assumes it is an ID and does not modify it. Otherwise, + returns a string corresponding to the run_object name. + + Returns: + A string representing the trainable identifier. + + Raises: + TuneError: if ``run_object`` passed in is invalid. 
+ """ + from ray.tune.search.sample import Domain + + if isinstance(run_object, str) or isinstance(run_object, Domain): + return run_object + elif isinstance(run_object, type) or callable(run_object): + name = "DEFAULT" + if hasattr(run_object, "_name"): + name = run_object._name + elif hasattr(run_object, "__name__"): + fn_name = run_object.__name__ + if fn_name == "<lambda>": + name = "lambda" + elif fn_name.startswith("<"): + name = "DEFAULT" + else: + name = fn_name + elif ( + isinstance(run_object, partial) + and hasattr(run_object, "func") + and hasattr(run_object.func, "__name__") + ): + name = run_object.func.__name__ + else: + logger.warning("No name detected on trainable. Using {}.".format(name)) + return name + else: + raise TuneError("Improper 'run' - not string nor trainable.") + + @classmethod + def register_if_needed(cls, run_object: Union[str, Callable, Type]): + """Registers Trainable or Function at runtime. + + Assumes already registered if run_object is a string. + Also, does not inspect interface of given run_object. + + Args: + run_object: Trainable to run. If string, + assumes it is an ID and does not modify it. Otherwise, + returns a string corresponding to the run_object name. + + Returns: + A string representing the trainable identifier. + """ + from ray.tune.search.sample import Domain + + if isinstance(run_object, str): + return run_object + elif isinstance(run_object, Domain): + logger.warning("Not registering trainable. Resolving as variant.") + return run_object + name = cls.get_trainable_name(run_object) + try: + register_trainable(name, run_object) + except (TypeError, PicklingError) as e: + extra_msg = ( + "Other options: " + "\n-Try reproducing the issue by calling " + "`pickle.dumps(trainable)`. " + "\n-If the error is typing-related, try removing " + "the type annotations and try again." + ) + raise type(e)(str(e) + " " + extra_msg) from None + return name + + @property + def stopper(self): + return self._stopper + + @property + def local_path(self) -> Optional[str]: + return self.storage.experiment_driver_staging_path + + @property + @Deprecated("Replaced by `local_path`") + def local_dir(self): + # TODO(justinvyu): [Deprecated] Remove in 2.11. + raise DeprecationWarning("Use `local_path` instead of `local_dir`.") + + @property + def remote_path(self) -> Optional[str]: + return self.storage.experiment_fs_path + + @property + def path(self) -> Optional[str]: + return self.remote_path or self.local_path + + @property + def checkpoint_config(self): + return self.spec.get("checkpoint_config") + + @property + @Deprecated("Replaced by `local_path`") + def checkpoint_dir(self): + # TODO(justinvyu): [Deprecated] Remove in 2.11. + raise DeprecationWarning("Use `local_path` instead of `checkpoint_dir`.") + + @property + def run_identifier(self): + """Returns a string representing the trainable identifier.""" + return self._run_identifier + + @property + def public_spec(self) -> Dict[str, Any]: + """Returns the spec dict with only the public-facing keys. + + Intended to be used for passing information to callbacks, + Searchers and Schedulers. + """ + return {k: v for k, v in self.spec.items() if k in self.PUBLIC_KEYS} + + +def _convert_to_experiment_list(experiments: Union[Experiment, List[Experiment], Dict]): + """Produces a list of Experiment objects. + + Converts input from dict, single experiment, or list of + experiments to list of experiments. If input is None, + will return an empty list. + + Arguments: + experiments: Experiments to run.
+ + Returns: + List of experiments. + """ + exp_list = experiments + + # Transform list if necessary + if experiments is None: + exp_list = [] + elif isinstance(experiments, Experiment): + exp_list = [experiments] + elif type(experiments) is dict: + exp_list = [ + Experiment.from_json(name, spec) for name, spec in experiments.items() + ] + + # Validate exp_list + if type(exp_list) is list and all(isinstance(exp, Experiment) for exp in exp_list): + if len(exp_list) > 1: + logger.info( + "Running with multiple concurrent experiments. " + "All experiments will be using the same SearchAlgorithm." + ) + else: + raise TuneError("Invalid argument: {}".format(experiments)) + + return exp_list diff --git a/.venv/lib/python3.11/site-packages/ray/tune/experiment/trial.py b/.venv/lib/python3.11/site-packages/ray/tune/experiment/trial.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce42f808af48ef0baddf0ce75abba490342ab79 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/experiment/trial.py @@ -0,0 +1,1073 @@ +import copy +import json +import logging +import os +import platform +import re +import time +import uuid +from contextlib import contextmanager +from functools import partial +from numbers import Number +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import ray +import ray.cloudpickle as cloudpickle +from ray._private.utils import binary_to_hex, hex_to_binary +from ray.air.constants import ( + EXPR_ERROR_FILE, + EXPR_ERROR_PICKLE_FILE, + TRAINING_ITERATION, +) +from ray.exceptions import RayActorError, RayTaskError +from ray.train import Checkpoint, CheckpointConfig +from ray.train._internal.checkpoint_manager import _CheckpointManager +from ray.train._internal.session import _FutureTrainingResult, _TrainingResult +from ray.train._internal.storage import StorageContext, _exists_at_fs_path +from ray.train.constants import ( + RAY_CHDIR_TO_TRIAL_DIR, + RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE, +) +from ray.tune.error import TuneError +from ray.tune.execution.placement_groups import ( + PlacementGroupFactory, + resource_dict_to_pg_factory, +) +from ray.tune.logger import NoopLogger + +# NOTE(rkn): We import ray.tune.registry here instead of importing the names we +# need because there are cyclic imports that may cause specific names to not +# have been defined yet. See https://github.com/ray-project/ray/issues/1716. +from ray.tune.registry import get_trainable_cls, validate_trainable +from ray.tune.result import ( + DEBUG_METRICS, + DONE, + NODE_IP, + PID, + STDERR_FILE, + STDOUT_FILE, + TRIAL_ID, + TRIAL_INFO, +) +from ray.tune.trainable.metadata import _TrainingRunMetadata +from ray.tune.utils import date_str, flatten_dict +from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder +from ray.util import log_once +from ray.util.annotations import Deprecated, DeveloperAPI + +DEBUG_PRINT_INTERVAL = 5 +_DEFAULT_WIN_MAX_PATH_LENGTH = 260 +TRIAL_STATE_FILENAME = "trial_metadata.json" + + +logger = logging.getLogger(__name__) + + +class _Location: + """Describes the location at which Trial is placed to run.""" + + def __init__(self, hostname=None, pid=None): + self.hostname = hostname + self.pid = pid + + def __str__(self): + if not self.pid: + return "" + elif self.hostname == platform.node(): + return "pid={}".format(self.pid) + else: + return "{}:{}".format(self.hostname, self.pid) + + +@DeveloperAPI +class ExportFormat: + """Describes the format to import/export the trial Trainable. 
+ + This may correspond to different file formats based on the + Trainable implementation. + """ + + CHECKPOINT = "checkpoint" + MODEL = "model" + ONNX = "onnx" + H5 = "h5" + + @staticmethod + def validate(formats): + """Validates formats. + + Raises: + ValueError: if the format is unknown. + """ + for i in range(len(formats)): + formats[i] = formats[i].strip().lower() + if formats[i] not in [ + ExportFormat.CHECKPOINT, + ExportFormat.MODEL, + ExportFormat.ONNX, + ExportFormat.H5, + ]: + raise TuneError("Unsupported import/export format: " + formats[i]) + + +class _TrialInfo: + """Serializable struct for holding information for a Trial. + + Attributes: + trial_name: String name of the current trial. + trial_id: trial_id of the trial + trial_resources: resources used by trial. + """ + + def __init__(self, trial: "Trial"): + self._trial_name = str(trial) + self._trial_id = trial.trial_id + self._trial_resources = trial.placement_group_factory + self._experiment_name = trial.experiment_dir_name + + @property + def experiment_name(self): + return self._experiment_name + + @property + def trial_name(self): + return self._trial_name + + @property + def trial_id(self): + return self._trial_id + + @property + def trial_resources(self) -> PlacementGroupFactory: + return self._trial_resources + + @trial_resources.setter + def trial_resources(self, new_resources: PlacementGroupFactory): + self._trial_resources = new_resources + + +class _TemporaryTrialState: + """Temporary trial state. + + Values saved here should not be restored on resume. + """ + + def __init__(self): + self.location = _Location() + + self.ray_actor: Optional[ray.actor.ActorHandle] = None + + self.saving_to: Optional[_FutureTrainingResult] = None + self.restoring_from: Optional[_TrainingResult] = None + + self.num_restore_failures: int = 0 + + def __getstate__(self): + return {} + + +def _get_max_path_length() -> int: + if hasattr(os, "pathconf"): + return os.pathconf("/", "PC_PATH_MAX") + # Windows + return _DEFAULT_WIN_MAX_PATH_LENGTH + + +def _create_unique_logdir_name(root: str, relative_logdir: str) -> str: + candidate = Path(root).expanduser().joinpath(relative_logdir) + if candidate.exists(): + relative_logdir_old = relative_logdir + relative_logdir += "_" + uuid.uuid4().hex[:4] + logger.info( + f"Creating a new dirname {relative_logdir} because " + f"trial dirname '{relative_logdir_old}' already exists." 
+ ) + return relative_logdir + + +def _noop_logger_creator(config: Dict[str, Any], logdir: str): + # Upon remote process setup, record the actor's original working dir before + # changing to the Tune logdir + os.environ.setdefault("TUNE_ORIG_WORKING_DIR", os.getcwd()) + + os.makedirs(logdir, exist_ok=True) + + if bool(int(os.environ.get(RAY_CHDIR_TO_TRIAL_DIR, "1"))): + # Set the working dir to the trial directory in the remote process, + # for user file writes + if not ray._private.worker._mode() == ray._private.worker.LOCAL_MODE: + os.chdir(logdir) + + return NoopLogger(config, logdir) + + +def _get_trainable_kwargs(trial: "Trial") -> Dict[str, Any]: + trial.init_local_path() + + logger_creator = partial( + _noop_logger_creator, logdir=trial.storage.trial_working_directory + ) + + trial_config = copy.deepcopy(trial.config) + trial_config[TRIAL_INFO] = _TrialInfo(trial) + stdout_file, stderr_file = trial.log_to_file + trial_config[STDOUT_FILE] = stdout_file + trial_config[STDERR_FILE] = stderr_file + + assert trial.storage.trial_dir_name + + kwargs = { + "config": trial_config, + "logger_creator": logger_creator, + "storage": trial.storage, + } + + return kwargs + + +@contextmanager +def _change_working_directory(trial): + """Context manager changing working directory to trial logdir. + Used in local mode. + + For non-local mode it is no-op. + """ + if ray._private.worker._mode() == ray._private.worker.LOCAL_MODE: + old_dir = os.getcwd() + try: + os.chdir(trial.local_path) + yield + finally: + os.chdir(old_dir) + else: + yield + + +@DeveloperAPI +class Trial: + """A trial object holds the state for one model training run. + + Trials are themselves managed by the TrialRunner class, which implements + the event loop for submitting trial runs to a Ray cluster. + + Trials start in the PENDING state, and transition to RUNNING once started. + On error, it transitions to ERROR, otherwise TERMINATED on success. + + There are resources allocated to each trial. These should be specified + using ``PlacementGroupFactory``. + + Attributes: + trainable_name: Name of the trainable object to be executed. + config: Provided configuration dictionary with evaluated params. + trial_id: Unique identifier for the trial. + path: Path where results for this trial are stored. Can be on + the local node or on cloud storage. + local_path: Path on the local disk where results are stored. + remote_path: Path on cloud storage where results are stored, + or None if not set. + relative_logdir: Directory of the trial relative to its + experiment directory. + evaluated_params: Evaluated parameters by search algorithm, + experiment_tag: Identifying trial name to show in the console + status: One of PENDING, RUNNING, PAUSED, TERMINATED, ERROR/ + error_file: Path to the errors that this trial has raised. 
+ + """ + + _nonjson_fields = [ + "results", + "extra_arg", + "placement_group_factory", + "_resources", + "_default_placement_group_factory", + ] + + PENDING = "PENDING" + RUNNING = "RUNNING" + PAUSED = "PAUSED" + TERMINATED = "TERMINATED" + ERROR = "ERROR" + + def __init__( + self, + trainable_name: str, + *, + config: Optional[Dict] = None, + trial_id: Optional[str] = None, + storage: Optional[StorageContext] = None, + evaluated_params: Optional[Dict] = None, + experiment_tag: str = "", + placement_group_factory: Optional[PlacementGroupFactory] = None, + stopping_criterion: Optional[Dict[str, float]] = None, + checkpoint_config: Optional[CheckpointConfig] = None, + export_formats: Optional[List[str]] = None, + restore_path: Optional[str] = None, + trial_name_creator: Optional[Callable[["Trial"], str]] = None, + trial_dirname_creator: Optional[Callable[["Trial"], str]] = None, + log_to_file: Union[Optional[str], Tuple[Optional[str], Optional[str]]] = None, + max_failures: int = 0, + stub: bool = False, + _setup_default_resource: bool = True, + ): + """Initialize a new trial. + + The args here take the same meaning as the command line flags defined + in ray.tune.experiment.config_parser. + + Args: + _setup_default_resource: Whether to set up default resources. + When initializing trials from checkpoints, this field is set to false, + so that setting up default resources can be delayed till after + ``trial.config`` is loaded from checkpoints. + """ + # If this is set, trainables are not validated or looked up. + # This can be used e.g. to initialize Trial objects from checkpoints + # without loading the trainable first. + self.stub = stub + + if not self.stub: + validate_trainable(trainable_name) + # Trial config + self.trainable_name = trainable_name + self.trial_id = Trial.generate_id() if trial_id is None else trial_id + + self.temporary_state = _TemporaryTrialState() + self.run_metadata = _TrainingRunMetadata() + + # Create a copy, since `init_local_path` updates the context with the + # generated trial dirname. + self.storage = copy.copy(storage) + + self.config = config or {} + # Save a copy of the original unresolved config so that we can swap + # out and update any reference config values after restoration. + self.__unresolved_config = self.config + + # Parameters that Tune varies across searches. + self.evaluated_params = evaluated_params or {} + self.experiment_tag = experiment_tag + self.stopping_criterion = stopping_criterion or {} + + self._setup_default_resource = _setup_default_resource + + if placement_group_factory and not isinstance( + placement_group_factory, PlacementGroupFactory + ): + placement_group_factory = resource_dict_to_pg_factory( + placement_group_factory + ) + + self._default_placement_group_factory = placement_group_factory + # Will be created in create_placement_group_factory(). 
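+ # Resolution order in that method: the trainable's default_resource_request() + # if it returns resources (combining it with a manual resources_per_trial + # raises a TuneError), else the user-provided resources_per_trial, else a + # single-CPU default.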
+ self.placement_group_factory = None + + self.log_to_file = log_to_file + # Make sure `stdout_file, stderr_file = Trial.log_to_file` works + if ( + not self.log_to_file + or not isinstance(self.log_to_file, Sequence) + or not len(self.log_to_file) == 2 + ): + self.log_to_file = (None, None) + + self.max_failures = max_failures + + # Local trial state that is updated during the run + self._default_result_or_future: Union[ray.ObjectRef, dict, None] = None + + self.export_formats = export_formats + self.status = Trial.PENDING + self.relative_logdir = None + + self.trial_name_creator = trial_name_creator + self.trial_dirname_creator = trial_dirname_creator + self.custom_trial_name = None + self.custom_dirname = None + + # Checkpoint config + checkpoint_config = checkpoint_config or CheckpointConfig() + + self.run_metadata.checkpoint_manager = _CheckpointManager( + checkpoint_config=checkpoint_config + ) + + # Restoration fields + self.restore_path = restore_path + self._restore_checkpoint_result: Optional[_TrainingResult] = None + if restore_path: + # tune.run(restore) passes in a path without metrics. + self._restore_checkpoint_result = _TrainingResult( + checkpoint=Checkpoint.from_directory(restore_path), metrics={} + ) + + if trial_name_creator: + self.custom_trial_name = trial_name_creator(self) + + if trial_dirname_creator: + self.custom_dirname = trial_dirname_creator(self) + if os.path.sep in self.custom_dirname: + raise ValueError( + f"Trial dirname must not contain '/'. Got {self.custom_dirname}" + ) + + self._state_json = None + + def create_placement_group_factory(self): + """Compute placement group factory if needed. + + Note: this must be called after all the placeholders in + self.config are resolved. + """ + trainable_cls = self.get_trainable_cls() + if not trainable_cls or not self._setup_default_resource: + # Create placement group factory using default resources. + self.placement_group_factory = ( + self._default_placement_group_factory or resource_dict_to_pg_factory() + ) + return + + default_resources = trainable_cls.default_resource_request(self.config) + + # If Trainable returns resources, do not allow manual override via + # `resources_per_trial` by the user. + if default_resources and self._default_placement_group_factory: + raise TuneError( + "Resources for {} have been automatically set to {} " + "by its `default_resource_request()` method. Please " + "clear the `resources_per_trial` option.".format( + trainable_cls, default_resources + ) + ) + + if default_resources and not isinstance( + default_resources, PlacementGroupFactory + ): + default_resources = resource_dict_to_pg_factory(default_resources) + + self.placement_group_factory = ( + # default_resource_request + default_resources + # resources_per_trial + or self._default_placement_group_factory + # cpu=1 + or resource_dict_to_pg_factory() + ) + + def _get_default_result_or_future(self) -> Optional[dict]: + """Calls ray.get on self._default_result_or_future and assigns back. + + Returns None in case of exceptions. + Will also set the trial location if runner is set. 
+ """ + if self._default_result_or_future and isinstance( + self._default_result_or_future, ray.ObjectRef + ): + try: + self._default_result_or_future = ray.get(self._default_result_or_future) + except RayActorError: # error during initialization + self._default_result_or_future = None + if self._default_result_or_future and self.temporary_state.ray_actor: + self.set_location( + _Location( + self._default_result_or_future.get(NODE_IP), + self._default_result_or_future.get(PID), + ) + ) + return self._default_result_or_future + + def resolve_config_placeholders(self, placeholder_resolvers: Dict[Tuple, Any]): + from ray.tune.impl.placeholder import resolve_placeholders + + # Make a copy of the unresolved config before resolve it. + self.config = copy.deepcopy(self.__unresolved_config) + resolve_placeholders(self.config, placeholder_resolvers) + + @property + def last_result(self) -> dict: + # The logic in here is as follows: + # 1. If the trial has reported at least once, last_result would have + # been set and therefore would not be empty. We can just return it. + # 2. If the trial has not reported at least once but we have the + # future for the default results dict, (obtained through + # Trainable.get_auto_filled_metrics), we get that future + # and return it. + # 3. In the worst case where we have nothing, we just set the + # trial_id and return that. + result = self.run_metadata.last_result + if not {k for k in result if k != TRIAL_ID}: + self._get_default_result_or_future() + result = self._default_result_or_future or result + result.setdefault(TRIAL_ID, self.trial_id) + return result + + @property + def metric_analysis(self): + return self.run_metadata.metric_analysis + + @property + def metric_n_steps(self): + return self.run_metadata.metric_n_steps + + def get_ray_actor_ip(self) -> Optional[str]: + if self.temporary_state.location.hostname: + return self.temporary_state.location.hostname + + if not self.temporary_state.ray_actor: + return None + + hostname, pid = ray.get( + self.temporary_state.ray_actor.get_current_ip_pid.remote() + ) + self.temporary_state.location = _Location(hostname, pid) + return self.temporary_state.location.hostname + + @property + @Deprecated("Replaced by `local_experiment_path`") + def local_dir(self): + return self.local_experiment_path + + @property + def experiment_dir_name(self): + return self.storage.experiment_dir_name + + @property + def remote_experiment_path(self) -> str: + return self.storage.experiment_fs_path + + @property + def local_experiment_path(self) -> str: + return self.storage.experiment_driver_staging_path + + @property + @Deprecated("Replaced by `local_path`") + def logdir(self) -> Optional[str]: + # TODO(justinvyu): [Deprecated] Remove in 2.11. 
+ raise DeprecationWarning("Use `local_path` instead of `logdir`.") + + @property + def local_path(self) -> Optional[str]: + return self.storage.trial_driver_staging_path + + @property + def path(self) -> Optional[str]: + return self.storage.trial_fs_path + + @property + def has_reported_at_least_once(self) -> bool: + return bool(self.run_metadata.last_result) + + @property + def node_ip(self): + return self.temporary_state.location.hostname + + @property + def checkpoint_at_end(self): + config = self.run_metadata.checkpoint_manager.checkpoint_config + return config.checkpoint_at_end + + @property + def checkpoint_freq(self): + config = self.run_metadata.checkpoint_manager.checkpoint_config + return config.checkpoint_frequency + + @property + def latest_checkpoint_result(self) -> Optional[_TrainingResult]: + # NOTE: Fallback to the checkpoint passed in from `tune.run(restore)` + # if the trial hasn't saved any checkpoints itself yet. + return ( + self.run_metadata.checkpoint_manager.latest_checkpoint_result + or self._restore_checkpoint_result + ) + + @property + def checkpoint(self) -> Optional[Checkpoint]: + """Returns the most recent checkpoint if one has been saved.""" + return ( + self.latest_checkpoint_result.checkpoint + if self.latest_checkpoint_result + else None + ) + + @classmethod + def generate_id(cls): + return str(uuid.uuid4().hex)[:8] + + def reset(self) -> "Trial": + # If there is `default_resource_request` associated with the trainable, + # clear `resources` and `placement_group_factory`. + # This is mainly relevant for RLlib tuning jobs, where we save users + # of the trouble to specify the resources themselves by having some + # default resources for popular RLlib algorithms. + trainable_cls = self.get_trainable_cls() + clear_resources = trainable_cls and trainable_cls.default_resource_request( + self.config + ) + placement_group_factory = ( + self.placement_group_factory if not clear_resources else None + ) + + checkpoint_config = self.run_metadata.checkpoint_manager.checkpoint_config + return Trial( + self.trainable_name, + config=self.config, + trial_id=None, + evaluated_params=self.evaluated_params, + experiment_tag=self.experiment_tag, + placement_group_factory=placement_group_factory, + stopping_criterion=self.stopping_criterion, + checkpoint_config=checkpoint_config, + export_formats=self.export_formats, + restore_path=self.restore_path, + trial_name_creator=self.trial_name_creator, + trial_dirname_creator=self.trial_dirname_creator, + log_to_file=self.log_to_file, + max_failures=self.max_failures, + storage=self.storage, + ) + + @Deprecated("Replaced by `init_local_path()`") + def init_logdir(self): + # TODO(justinvyu): [Deprecated] Remove in 2.11. + raise DeprecationWarning("Use `init_local_path` instead of `init_logdir`.") + + def init_local_path(self): + """Init logdir.""" + if not self.relative_logdir: + self.relative_logdir = _create_unique_logdir_name( + str(self.local_experiment_path), self._generate_dirname() + ) + # Populate the storage context with the trial dir name we just generated. + self.storage.trial_dir_name = self.relative_logdir + + assert self.local_path + logdir_path = Path(self.local_path) + max_path_length = _get_max_path_length() + if len(str(logdir_path)) >= max_path_length: + logger.warning( + f"The path to the trial log directory is too long " + f"(max length: {max_path_length}. " + f"Consider using `trial_dirname_creator` to shorten the path. 
" + f"Path: {logdir_path}" + ) + logdir_path.mkdir(parents=True, exist_ok=True) + + self.invalidate_json_state() + + def update_resources(self, resources: Union[dict, PlacementGroupFactory]): + """EXPERIMENTAL: Updates the resource requirements. + + Should only be called when the trial is not running. + + Raises: + ValueError: if trial status is running. + """ + if self.status is Trial.RUNNING: + raise ValueError("Cannot update resources while Trial is running.") + + placement_group_factory = resources + if isinstance(resources, dict): + placement_group_factory = resource_dict_to_pg_factory(resources) + + self.placement_group_factory = placement_group_factory + + self.invalidate_json_state() + + def set_ray_actor(self, ray_actor): + self.temporary_state.ray_actor = ray_actor + if ray_actor: + # Do not block here, the result will be gotten when last_result + # property is accessed + self._default_result_or_future = ray_actor.get_auto_filled_metrics.remote( + debug_metrics_only=True + ) + + def set_location(self, location): + """Sets the location of the trial.""" + self.temporary_state.location = location + + def set_status(self, status): + """Sets the status of the trial.""" + self.status = status + if status == Trial.RUNNING: + if self.run_metadata.start_time is None: + self.run_metadata.start_time = time.time() + self.invalidate_json_state() + + def set_config(self, config): + self.config = config + self.invalidate_json_state() + + def set_experiment_tag(self, experiment_tag): + self.experiment_tag = experiment_tag + self.invalidate_json_state() + + def set_storage(self, new_storage: StorageContext): + """Updates the storage context of the trial. + + If the `storage_path` or `experiment_dir_name` has changed, then this setter + also updates the paths of all checkpoints tracked by the checkpoint manager. + This enables restoration from a checkpoint if the user moves the directory. + """ + original_storage = self.storage + + checkpoint_manager = self.run_metadata.checkpoint_manager + + for checkpoint_result in checkpoint_manager.best_checkpoint_results: + checkpoint_result.checkpoint = Checkpoint( + path=checkpoint_result.checkpoint.path.replace( + original_storage.trial_fs_path, new_storage.trial_fs_path, 1 + ), + filesystem=new_storage.storage_filesystem, + ) + latest_checkpoint_result = checkpoint_manager.latest_checkpoint_result + if latest_checkpoint_result: + latest_checkpoint_result.checkpoint = Checkpoint( + path=latest_checkpoint_result.checkpoint.path.replace( + original_storage.trial_fs_path, new_storage.trial_fs_path, 1 + ), + filesystem=new_storage.storage_filesystem, + ) + + self.storage = new_storage + self.invalidate_json_state() + + @property + def num_failures(self): + return self.run_metadata.num_failures + + @property + def num_failures_after_restore(self): + return self.run_metadata.num_failures_after_restore + + @property + def error_file(self): + if not self.local_path or not self.run_metadata.error_filename: + return None + return Path(self.local_path, self.run_metadata.error_filename).as_posix() + + @property + def pickled_error_file(self): + if not self.local_path or not self.run_metadata.pickled_error_filename: + return None + return Path( + self.local_path, self.run_metadata.pickled_error_filename + ).as_posix() + + def get_pickled_error(self) -> Optional[Exception]: + """Returns the pickled error object if it exists in storage. + + This is a pickled version of the latest error that the trial encountered. 
+ """ + error_filename = self.run_metadata.pickled_error_filename + if error_filename is None: + return None + + fs = self.storage.storage_filesystem + pickled_error_fs_path = Path( + self.storage.trial_fs_path, error_filename + ).as_posix() + + if _exists_at_fs_path(fs=fs, fs_path=pickled_error_fs_path): + with fs.open_input_stream(pickled_error_fs_path) as f: + return cloudpickle.loads(f.readall()) + return None + + def get_error(self) -> Optional[TuneError]: + """Returns the error text file trace as a TuneError object + if it exists in storage. + + This is a text trace of the latest error that the trial encountered, + which is used in the case that the error is not picklable. + """ + error_filename = self.run_metadata.error_filename + if error_filename is None: + return None + + fs = self.storage.storage_filesystem + txt_error_fs_path = Path(self.storage.trial_fs_path, error_filename).as_posix() + + if _exists_at_fs_path(fs=fs, fs_path=txt_error_fs_path): + with fs.open_input_stream(txt_error_fs_path) as f: + return f.readall().decode() + return None + + def _handle_restore_error(self, exc: Exception): + # For Restoration errors, we only increment the restore failure count + # if the number of failures exceeds the restore retry limit. + if self.temporary_state.num_restore_failures >= int( + os.environ.get("TUNE_RESTORE_RETRY_NUM", 0) + ): + self.run_metadata.num_failures += 1 + else: + self.temporary_state.num_restore_failures += 1 + + def _handle_ray_actor_error(self, exc: RayActorError): + count_preemption_errors = bool( + int(os.environ.get(RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE, "0")) + ) + if not exc.preempted or count_preemption_errors: + # Only count non-preempted actor errors as failures. + self.run_metadata.num_failures += 1 + + def _handle_ray_task_error(self, exc: RayTaskError): + cause = exc.as_instanceof_cause() + if isinstance(cause, RayActorError): + # Handle the RayActorError directly (ex: Ray Train worker actor errors) + return self._handle_ray_actor_error(cause) + + # Increment failures for all user errors (which get raised as RayTaskError) + self.run_metadata.num_failures += 1 + + def handle_error( + self, exc: Optional[Union[TuneError, RayTaskError, RayActorError]] = None + ): + if self.is_restoring: + self._handle_restore_error(exc) + elif isinstance(exc, RayActorError): + self._handle_ray_actor_error(exc) + elif isinstance(exc, RayTaskError): + self._handle_ray_task_error(exc) + else: + self.run_metadata.num_failures += 1 + + if self.local_path: + self.run_metadata.error_filename = EXPR_ERROR_FILE + if isinstance(exc, (RayTaskError, RayActorError)): + # Piping through the actual error to result grid. + self.run_metadata.pickled_error_filename = EXPR_ERROR_PICKLE_FILE + with open(self.pickled_error_file, "wb") as f: + cloudpickle.dump(exc, f) + with open(self.error_file, "a+") as f: + f.write( + "Failure # {} (occurred at {})\n".format( + self.run_metadata.num_failures, date_str() + ) + ) + f.write(str(exc) + "\n") + self.run_metadata.invalidate_cache() + + def should_stop(self, result): + """Whether the given result meets this trial's stopping criteria.""" + if result.get(DONE): + return True + + for criterion, stop_value in self.stopping_criterion.items(): + if isinstance(criterion, dict): + raise ValueError( + "Stopping criteria is now flattened by default. " + "Use forward slashes to nest values `key1/key2/key3`." 
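The error and pickled-error files written above are what ultimately back the user-facing result objects. A hedged sketch of retrieving a failed trial's exception after a run, assuming the public Tuner/ResultGrid API:

    from ray import tune

    def failing_fn(config):
        raise RuntimeError("boom")

    results = tune.Tuner(failing_fn).fit()

    # Each errored trial exposes its exception (unpickled if it was picklable,
    # otherwise only the text trace is available).
    for result in results:
        if result.error:
            print("Trial failed:", repr(result.error))

    print(results.errors)  # all trial errors collected by the ResultGrid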
+ ) + elif criterion not in result: + if log_once("tune_trial_stop_criterion_not_found"): + logger.warning( + f"Stopping criterion '{criterion}' not found in result dict! " + f"Available keys are {list(result.keys())}. If '{criterion}' is" + " never reported, the run will continue until training is " + "finished." + ) + elif result[criterion] >= stop_value: + return True + return False + + def should_checkpoint(self): + """Whether this trial is due for checkpointing.""" + result = self.last_result or {} + if result.get(DONE) and self.checkpoint_at_end: + return True + return ( + self.checkpoint_freq + and result.get(TRAINING_ITERATION, 0) % self.checkpoint_freq == 0 + ) + + def has_checkpoint(self) -> bool: + return self.checkpoint is not None + + def on_checkpoint(self, checkpoint_result: _TrainingResult): + """Hook for handling checkpoints taken by the Trainable. + + Args: + checkpoint: Checkpoint taken. + """ + self.run_metadata.checkpoint_manager.register_checkpoint(checkpoint_result) + # Update the checkpoint index to keep the checkpoint index in sync. + # This index will get restored when the trial is restored and will + # be passed to the Trainable as the starting checkpoint index. + self.storage._update_checkpoint_index(checkpoint_result.metrics) + + self.invalidate_json_state() + self.run_metadata.invalidate_cache() + + def on_restore(self): + """Handles restoration completion.""" + assert self.is_restoring + self.run_metadata.last_result = self.temporary_state.restoring_from.metrics + self.run_metadata.last_result.setdefault("config", self.config) + self.temporary_state.restoring_from = None + self.temporary_state.num_restore_failures = 0 + + def should_recover(self): + """Returns whether the trial qualifies for retrying. + + `num_failures` should represent the number of times the trial has + failed *up to the moment this method is called.* If we've failed + 5 times and `max_failures=5`, then we should recover, since + we only pass the limit on the 6th failure. + + Note this may return true even when there is no checkpoint, either because + `self.checkpoint_freq` is `0` or because the trial failed before + a checkpoint has been made. + """ + return ( + self.run_metadata.num_failures <= self.max_failures or self.max_failures < 0 + ) + + def update_last_result(self, result): + if self.experiment_tag: + result.update(experiment_tag=self.experiment_tag) + + self.set_location(_Location(result.get(NODE_IP), result.get(PID))) + self.run_metadata.last_result = result + self.run_metadata.last_result_time = time.time() + + metric_result = self.last_result.copy() + for remove_metric in DEBUG_METRICS: + metric_result.pop(remove_metric, None) + + for metric, value in flatten_dict(metric_result).items(): + if isinstance(value, Number): + self.run_metadata.update_metric( + metric, value, step=result.get("training_iteration") + ) + + def get_trainable_cls(self): + if self.stub: + return None + return get_trainable_cls(self.trainable_name) + + def is_finished(self): + return self.status in [Trial.ERROR, Trial.TERMINATED] + + @property + def is_restoring(self): + return self.temporary_state.restoring_from is not None + + @property + def is_saving(self): + return self.temporary_state.saving_to is not None + + def __repr__(self): + return self._trainable_name(include_trial_id=True) + + def __str__(self): + return self._trainable_name(include_trial_id=True) + + def _trainable_name(self, include_trial_id=False): + """Combines ``env`` with ``trainable_name`` and ``trial_id``. 
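To make the should_stop() and should_recover() contracts concrete, this is roughly how the corresponding user-facing options are set; the RunConfig/FailureConfig import locations are assumed for a recent Ray 2.x and may differ between versions.

    from ray import tune
    from ray.train import FailureConfig, RunConfig

    def train_fn(config):
        for i in range(1000):
            tune.report(metrics={"mean_accuracy": i / 1000.0})

    tuner = tune.Tuner(
        train_fn,
        run_config=RunConfig(
            # Criteria are matched against the flattened result dict; nested
            # metrics are addressed with "/" separators, e.g. "info/learner/loss".
            stop={"mean_accuracy": 0.95, "training_iteration": 200},
            # Retry a crashed trial up to 3 times before marking it ERROR;
            # max_failures=-1 retries indefinitely.
            failure_config=FailureConfig(max_failures=3),
        ),
    )
    tuner.fit()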
+ + Can be overridden with a custom string creator. + """ + if self.custom_trial_name: + return self.custom_trial_name + + if "env" in self.config: + env = self.config["env"] + if isinstance(env, type): + env = env.__name__ + identifier = "{}_{}".format(self.trainable_name, env) + else: + identifier = self.trainable_name + if include_trial_id: + identifier += "_" + self.trial_id + return identifier.replace("/", "_") + + def _generate_dirname(self): + if self.custom_dirname: + generated_dirname = self.custom_dirname + else: + MAX_LEN_IDENTIFIER = int(os.environ.get("TUNE_MAX_LEN_IDENTIFIER", "130")) + generated_dirname = f"{str(self)}_{self.experiment_tag}" + generated_dirname = generated_dirname[:MAX_LEN_IDENTIFIER] + generated_dirname += f"_{date_str()}" + # This is the file path used by rsync. ['/', '(', ')'] are not allowed. + return re.sub("[/()]", "_", generated_dirname) + + def invalidate_json_state(self): + self._state_json = None + + def get_json_state(self) -> Tuple[str, str]: + if self._state_json is None: + state = self.__getstate__() + state.pop("run_metadata", None) + self._state_json = json.dumps(state, indent=2, cls=TuneFunctionEncoder) + + runtime_metadata_json = self.run_metadata.get_json_state() + + return self._state_json, runtime_metadata_json + + @classmethod + def from_json_state(cls, json_state: str, stub: bool = False) -> "Trial": + state = json.loads(json_state, cls=TuneFunctionDecoder) + + new_trial = Trial( + state["trainable_name"], + stub=stub, + _setup_default_resource=False, + ) + + new_trial.__setstate__(state) + + return new_trial + + def restore_run_metadata(self, run_metadata: str): + self.run_metadata = _TrainingRunMetadata.from_json_state(run_metadata) + + @classmethod + def from_directory( + cls, path: Union[str, os.PathLike], stub: bool = False + ) -> "Trial": + metadata_path = Path(path, TRIAL_STATE_FILENAME) + if not metadata_path.exists(): + raise FileNotFoundError( + f"Can't restore trial from path: File `{metadata_path}` not found." + ) + + json_state = metadata_path.read_text() + return cls.from_json_state(json_state, stub=stub) + + def __getstate__(self): + """Memento generator for Trial. + + Sets RUNNING trials to PENDING. + Note this can only occur if the trial holds a PERSISTENT checkpoint. 
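The trial_name_creator / trial_dirname_creator hooks and the TUNE_MAX_LEN_IDENTIFIER truncation above are driven from user code roughly like this (a sketch assuming the TuneConfig fields of recent Ray 2.x):

    from ray import tune

    def train_fn(config):
        tune.report(metrics={"loss": config["lr"]})

    tuner = tune.Tuner(
        train_fn,
        param_space={"lr": tune.grid_search([0.01, 0.1])},
        tune_config=tune.TuneConfig(
            # Shown in logs and in Trial.__repr__ (see _trainable_name above).
            trial_name_creator=lambda trial: f"lr={trial.config['lr']}",
            # Used for the on-disk directory (see _generate_dirname above);
            # must not contain path separators.
            trial_dirname_creator=lambda trial: f"trial_{trial.trial_id}",
        ),
    )
    tuner.fit()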
+ """ + state = self.__dict__.copy() + + for key in self._nonjson_fields: + state[key] = binary_to_hex(cloudpickle.dumps(state.get(key))) + + state.pop("temporary_state", None) + + state["_state_json"] = None + state["_default_result_or_future"] = None + + return state + + def __setstate__(self, state): + if state["status"] == Trial.RUNNING: + state["status"] = Trial.PENDING + for key in self._nonjson_fields: + if key in state: + state[key] = cloudpickle.loads(hex_to_binary(state[key])) + + # Ensure that stub doesn't get overriden + stub = state.pop("stub", True) + self.__dict__.update(state) + self.stub = stub or getattr(self, "stub", False) + + if not self.stub: + validate_trainable(self.trainable_name) + + self.temporary_state = _TemporaryTrialState() + + assert self.placement_group_factory diff --git a/.venv/lib/python3.11/site-packages/ray/tune/integration/__init__.py b/.venv/lib/python3.11/site-packages/ray/tune/integration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/keras.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/keras.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd578c89e1b3076fde978b04552b098def93b4fa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/keras.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/lightgbm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/lightgbm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b484f065baad2c8f7b3a712c12008616caeca8db Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/lightgbm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/pytorch_lightning.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/pytorch_lightning.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..820b9eb497174745c590c5063b751ce16a222000 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/pytorch_lightning.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/ray_train.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/ray_train.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caa88ab53cd2323f871d2e8350044c8cfd5a4ced Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/ray_train.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/xgboost.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/xgboost.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8404be42e3deb48aaa8854563dc8456b6f6e42e2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/xgboost.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/integration/keras.py b/.venv/lib/python3.11/site-packages/ray/tune/integration/keras.py new file mode 100644 index 0000000000000000000000000000000000000000..8733f0205005d60c4c1321b0cabb09dd1315b88b --- /dev/null 
+++ b/.venv/lib/python3.11/site-packages/ray/tune/integration/keras.py @@ -0,0 +1,28 @@ +_DEPRECATION_MESSAGE = ( + "The `ray.tune.integration.keras` module is deprecated in favor of " + "`ray.train.tensorflow.keras.ReportCheckpointCallback`." +) + + +class TuneReportCallback: + """Deprecated. + Use :class:`ray.train.tensorflow.keras.ReportCheckpointCallback` instead.""" + + def __new__(cls, *args, **kwargs): + raise DeprecationWarning(_DEPRECATION_MESSAGE) + + +class _TuneCheckpointCallback: + """Deprecated. + Use :class:`ray.train.tensorflow.keras.ReportCheckpointCallback` instead.""" + + def __new__(cls, *args, **kwargs): + raise DeprecationWarning(_DEPRECATION_MESSAGE) + + +class TuneReportCheckpointCallback: + """Deprecated. + Use :class:`ray.train.tensorflow.keras.ReportCheckpointCallback` instead.""" + + def __new__(cls, *args, **kwargs): + raise DeprecationWarning(_DEPRECATION_MESSAGE) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/integration/lightgbm.py b/.venv/lib/python3.11/site-packages/ray/tune/integration/lightgbm.py new file mode 100644 index 0000000000000000000000000000000000000000..778ba5ee2318bf57dc463953b36175ad94374e59 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/integration/lightgbm.py @@ -0,0 +1,14 @@ +from ray.train.lightgbm import ( # noqa: F401 + RayTrainReportCallback as TuneReportCheckpointCallback, +) +from ray.util.annotations import Deprecated + + +@Deprecated +class TuneReportCallback: + def __new__(cls: type, *args, **kwargs): + # TODO(justinvyu): [code_removal] Remove in 2.11. + raise DeprecationWarning( + "`TuneReportCallback` is deprecated. " + "Use `ray.tune.integration.lightgbm.TuneReportCheckpointCallback` instead." + ) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/integration/pytorch_lightning.py b/.venv/lib/python3.11/site-packages/ray/tune/integration/pytorch_lightning.py new file mode 100644 index 0000000000000000000000000000000000000000..0ca554ebced800222886c6b1b6cf00e933ad150a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/integration/pytorch_lightning.py @@ -0,0 +1,206 @@ +import inspect +import logging +import os +import tempfile +import warnings +from contextlib import contextmanager +from typing import Dict, List, Optional, Type, Union + +from ray import train +from ray.train import Checkpoint +from ray.util import log_once +from ray.util.annotations import Deprecated, PublicAPI + +try: + from lightning import Callback, LightningModule, Trainer +except ModuleNotFoundError: + from pytorch_lightning import Callback, LightningModule, Trainer + + +logger = logging.getLogger(__name__) + +# Get all Pytorch Lightning Callback hooks based on whatever PTL version is being used. +_allowed_hooks = { + name + for name, fn in inspect.getmembers(Callback, predicate=inspect.isfunction) + if name.startswith("on_") +} + + +def _override_ptl_hooks(callback_cls: Type["TuneCallback"]) -> Type["TuneCallback"]: + """Overrides all allowed PTL Callback hooks with our custom handle logic.""" + + def generate_overridden_hook(fn_name): + def overridden_hook( + self, + trainer: Trainer, + *args, + pl_module: Optional[LightningModule] = None, + **kwargs, + ): + if fn_name in self._on: + self._handle(trainer=trainer, pl_module=pl_module) + + return overridden_hook + + # Set the overridden hook to all the allowed hooks in TuneCallback. 
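As the deprecation shims above indicate, the maintained callbacks now live under ray.train (e.g. ray.train.tensorflow.keras.ReportCheckpointCallback for Keras, ray.train.lightgbm.RayTrainReportCallback for LightGBM). A hedged LightGBM sketch inside a Tune function trainable; the toy dataset and parameters are illustrative only.

    import lightgbm as lgb
    import numpy as np
    from ray import tune
    from ray.train.lightgbm import RayTrainReportCallback

    def train_fn(config):
        X = np.random.rand(200, 4)
        y = (X[:, 0] > 0.5).astype(int)
        train_set = lgb.Dataset(X, label=y)
        lgb.train(
            {"objective": "binary", "num_leaves": config["num_leaves"]},
            train_set,
            valid_sets=[train_set],
            valid_names=["train"],
            # Reports the evaluation results back to the Tune session.
            callbacks=[RayTrainReportCallback()],
        )

    tune.Tuner(train_fn, param_space={"num_leaves": tune.choice([15, 31])}).fit()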
+ for fn_name in _allowed_hooks: + setattr(callback_cls, fn_name, generate_overridden_hook(fn_name)) + + return callback_cls + + +@_override_ptl_hooks +class TuneCallback(Callback): + """Base class for Tune's PyTorch Lightning callbacks. + + Args: + When to trigger checkpoint creations. Must be one of + the PyTorch Lightning event hooks (less the ``on_``), e.g. + "train_batch_start", or "train_end". Defaults to "validation_end" + """ + + def __init__(self, on: Union[str, List[str]] = "validation_end"): + if not isinstance(on, list): + on = [on] + + for hook in on: + if f"on_{hook}" not in _allowed_hooks: + raise ValueError( + f"Invalid hook selected: {hook}. Must be one of " + f"{_allowed_hooks}" + ) + + # Add back the "on_" prefix for internal consistency. + on = [f"on_{hook}" for hook in on] + + self._on = on + + def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]): + raise NotImplementedError + + +@PublicAPI +class TuneReportCheckpointCallback(TuneCallback): + """PyTorch Lightning report and checkpoint callback + + Saves checkpoints after each validation step. Also reports metrics to Tune, + which is needed for checkpoint registration. + + Args: + metrics: Metrics to report to Tune. If this is a list, + each item describes the metric key reported to PyTorch Lightning, + and it will reported under the same name to Tune. If this is a + dict, each key will be the name reported to Tune and the respective + value will be the metric key reported to PyTorch Lightning. + filename: Filename of the checkpoint within the checkpoint + directory. Defaults to "checkpoint". + save_checkpoints: If True (default), checkpoints will be saved and + reported to Ray. If False, only metrics will be reported. + on: When to trigger checkpoint creations and metric reports. Must be one of + the PyTorch Lightning event hooks (less the ``on_``), e.g. + "train_batch_start", or "train_end". Defaults to "validation_end". + + + Example: + + .. code-block:: python + + import pytorch_lightning as pl + from ray.tune.integration.pytorch_lightning import ( + TuneReportCheckpointCallback) + + # Save checkpoint after each training batch and after each + # validation epoch. + trainer = pl.Trainer(callbacks=[TuneReportCheckpointCallback( + metrics={"loss": "val_loss", "mean_accuracy": "val_acc"}, + filename="trainer.ckpt", on="validation_end")]) + + + """ + + def __init__( + self, + metrics: Optional[Union[str, List[str], Dict[str, str]]] = None, + filename: str = "checkpoint", + save_checkpoints: bool = True, + on: Union[str, List[str]] = "validation_end", + ): + super(TuneReportCheckpointCallback, self).__init__(on=on) + if isinstance(metrics, str): + metrics = [metrics] + self._save_checkpoints = save_checkpoints + self._filename = filename + self._metrics = metrics + + def _get_report_dict(self, trainer: Trainer, pl_module: LightningModule): + # Don't report if just doing initial validation sanity checks. + if trainer.sanity_checking: + return + if not self._metrics: + report_dict = {k: v.item() for k, v in trainer.callback_metrics.items()} + else: + report_dict = {} + for key in self._metrics: + if isinstance(self._metrics, dict): + metric = self._metrics[key] + else: + metric = key + if metric in trainer.callback_metrics: + report_dict[key] = trainer.callback_metrics[metric].item() + else: + logger.warning( + f"Metric {metric} does not exist in " + "`trainer.callback_metrics." 
+ ) + + return report_dict + + @contextmanager + def _get_checkpoint(self, trainer: Trainer) -> Optional[Checkpoint]: + if not self._save_checkpoints: + yield None + return + + with tempfile.TemporaryDirectory() as checkpoint_dir: + trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename)) + checkpoint = Checkpoint.from_directory(checkpoint_dir) + yield checkpoint + + def _handle(self, trainer: Trainer, pl_module: LightningModule): + if trainer.sanity_checking: + return + + report_dict = self._get_report_dict(trainer, pl_module) + if not report_dict: + return + + with self._get_checkpoint(trainer) as checkpoint: + train.report(report_dict, checkpoint=checkpoint) + + +class _TuneCheckpointCallback(TuneCallback): + def __init__(self, *args, **kwargs): + raise DeprecationWarning( + "`ray.tune.integration.pytorch_lightning._TuneCheckpointCallback` " + "is deprecated." + ) + + +@Deprecated +class TuneReportCallback(TuneReportCheckpointCallback): + def __init__( + self, + metrics: Optional[Union[str, List[str], Dict[str, str]]] = None, + on: Union[str, List[str]] = "validation_end", + ): + if log_once("tune_ptl_report_deprecated"): + warnings.warn( + "`ray.tune.integration.pytorch_lightning.TuneReportCallback` " + "is deprecated. Use " + "`ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback`" + " instead." + ) + super(TuneReportCallback, self).__init__( + metrics=metrics, save_checkpoints=False, on=on + ) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/integration/ray_train.py b/.venv/lib/python3.11/site-packages/ray/tune/integration/ray_train.py new file mode 100644 index 0000000000000000000000000000000000000000..b947eab01df00e8fe80f721fd838489b1f814b7a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/integration/ray_train.py @@ -0,0 +1,33 @@ +from typing import Any, Dict, List, Optional + +import ray.tune +from ray.train import Checkpoint as RayTrainCheckpoint +from ray.train.v2._internal.execution.context import TrainRunContext +from ray.train.v2.api.callback import UserCallback +from ray.util.annotations import DeveloperAPI + + +CHECKPOINT_PATH_KEY = "checkpoint_path" + + +@DeveloperAPI +class TuneReportCallback(UserCallback): + """Propagate metrics and checkpoint paths from Ray Train workers to Ray Tune.""" + + def after_report( + self, + run_context: TrainRunContext, + metrics: List[Dict[str, Any]], + checkpoint: Optional[RayTrainCheckpoint], + ): + # TODO: This can be changed to aggregate the metrics from all workers. + # For now, just achieve feature parity with the old Tune+Train integration. + metrics = metrics[0].copy() + + # If a checkpoint is provided, add the checkpoint path to the metrics. + # Don't report the checkpoint again since it's already been uploaded + # to storage. + if checkpoint: + metrics[CHECKPOINT_PATH_KEY] = checkpoint.path + + ray.tune.report(metrics=metrics) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/integration/xgboost.py b/.venv/lib/python3.11/site-packages/ray/tune/integration/xgboost.py new file mode 100644 index 0000000000000000000000000000000000000000..fadb64ec4be1353b0f722c37406c5ba1a048a096 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/integration/xgboost.py @@ -0,0 +1,14 @@ +from ray.train.xgboost import ( # noqa: F401 + RayTrainReportCallback as TuneReportCheckpointCallback, +) +from ray.util.annotations import Deprecated + + +@Deprecated +class TuneReportCallback: + def __new__(cls: type, *args, **kwargs): + # TODO(justinvyu): [code_removal] Remove in 2.11. 
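Complementing the docstring example above, this is roughly how TuneReportCheckpointCallback is wired into a Tune run. MyLightningModule is a hypothetical placeholder for the user's own LightningModule.

    import pytorch_lightning as pl
    from ray import tune
    from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

    def train_fn(config):
        model = MyLightningModule(lr=config["lr"])  # hypothetical user module
        trainer = pl.Trainer(
            max_epochs=5,
            enable_progress_bar=False,
            callbacks=[
                TuneReportCheckpointCallback(
                    # Tune metric name -> key in trainer.callback_metrics
                    metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
                    filename="trainer.ckpt",
                    on="validation_end",
                )
            ],
        )
        trainer.fit(model)

    tune.Tuner(train_fn, param_space={"lr": tune.loguniform(1e-4, 1e-1)}).fit()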
+ raise DeprecationWarning( + "`TuneReportCallback` is deprecated. " + "Use `ray.tune.integration.xgboost.TuneReportCheckpointCallback` instead." + ) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/progress_reporter.py b/.venv/lib/python3.11/site-packages/ray/tune/progress_reporter.py new file mode 100644 index 0000000000000000000000000000000000000000..2d83fb814408abe81a5f6006396ff1db70ef0e4d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/progress_reporter.py @@ -0,0 +1,1596 @@ +from __future__ import print_function + +import collections +import datetime +import numbers +import sys +import textwrap +import time +import warnings +from pathlib import Path +from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd + +import ray +from ray._private.dict import flatten_dict +from ray._private.thirdparty.tabulate.tabulate import tabulate +from ray.air.constants import EXPR_ERROR_FILE, TRAINING_ITERATION +from ray.air.util.node import _force_on_current_node +from ray.experimental.tqdm_ray import safe_print +from ray.tune.callback import Callback +from ray.tune.experiment.trial import DEBUG_PRINT_INTERVAL, Trial, _Location +from ray.tune.logger import pretty_print +from ray.tune.result import ( + AUTO_RESULT_KEYS, + DEFAULT_METRIC, + DONE, + EPISODE_REWARD_MEAN, + EXPERIMENT_TAG, + MEAN_ACCURACY, + MEAN_LOSS, + NODE_IP, + PID, + TIME_TOTAL_S, + TIMESTEPS_TOTAL, + TRIAL_ID, +) +from ray.tune.trainable import Trainable +from ray.tune.utils import unflattened_lookup +from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity +from ray.util.annotations import DeveloperAPI, PublicAPI +from ray.util.queue import Empty, Queue +from ray.widgets import Template + +try: + from collections.abc import Mapping, MutableMapping +except ImportError: + from collections import Mapping, MutableMapping + + +IS_NOTEBOOK = ray.widgets.util.in_notebook() + +SKIP_RESULTS_IN_REPORT = {"config", TRIAL_ID, EXPERIMENT_TAG, DONE} + + +@PublicAPI +class ProgressReporter: + """Abstract class for experiment progress reporting. + + `should_report()` is called to determine whether or not `report()` should + be called. Tune will call these functions after trial state transitions, + receiving training results, and so on. + """ + + def setup( + self, + start_time: Optional[float] = None, + total_samples: Optional[int] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + **kwargs, + ): + """Setup progress reporter for a new Ray Tune run. + + This function is used to initialize parameters that are set on runtime. + It will be called before any of the other methods. + + Defaults to no-op. + + Args: + start_time: Timestamp when the Ray Tune run is started. + total_samples: Number of samples the Ray Tune run will run. + metric: Metric to optimize. + mode: Must be one of [min, max]. Determines whether objective is + minimizing or maximizing the metric attribute. + **kwargs: Keyword arguments for forward-compatibility. + """ + pass + + def should_report(self, trials: List[Trial], done: bool = False): + """Returns whether or not progress should be reported. + + Args: + trials: Trials to report on. + done: Whether this is the last progress report attempt. + """ + raise NotImplementedError + + def report(self, trials: List[Trial], done: bool, *sys_info: Dict): + """Reports progress across trials. + + Args: + trials: Trials to report on. + done: Whether this is the last progress report attempt. + sys_info: System info. 
+ """ + raise NotImplementedError + + +@DeveloperAPI +class TuneReporterBase(ProgressReporter): + """Abstract base class for the default Tune reporters. + + If metric_columns is not overridden, Tune will attempt to automatically + infer the metrics being outputted, up to 'infer_limit' number of + metrics. + + Args: + metric_columns: Names of metrics to + include in progress table. If this is a dict, the keys should + be metric names and the values should be the displayed names. + If this is a list, the metric name is used directly. + parameter_columns: Names of parameters to + include in progress table. If this is a dict, the keys should + be parameter names and the values should be the displayed names. + If this is a list, the parameter name is used directly. If empty, + defaults to all available parameters. + max_progress_rows: Maximum number of rows to print + in the progress table. The progress table describes the + progress of each trial. Defaults to 20. + max_error_rows: Maximum number of rows to print in the + error table. The error table lists the error file, if any, + corresponding to each trial. Defaults to 20. + max_column_length: Maximum column length (in characters). Column + headers and values longer than this will be abbreviated. + max_report_frequency: Maximum report frequency in seconds. + Defaults to 5s. + infer_limit: Maximum number of metrics to automatically infer + from tune results. + print_intermediate_tables: Print intermediate result + tables. If None (default), will be set to True for verbosity + levels above 3, otherwise False. If True, intermediate tables + will be printed with experiment progress. If False, tables + will only be printed at then end of the tuning run for verbosity + levels greater than 2. + metric: Metric used to determine best current trial. + mode: One of [min, max]. Determines whether objective is + minimizing or maximizing the metric attribute. + sort_by_metric: Sort terminated trials by metric in the + intermediate table. Defaults to False. + """ + + # Truncated representations of column names (to accommodate small screens). 
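A minimal custom reporter built on the abstract interface above. Attaching it via tune.run(progress_reporter=...) is one option (a RunConfig field exists in some versions); whether custom reporters are honored can depend on the console output engine of the Ray version in use.

    from ray import tune
    from ray.tune import ProgressReporter

    class FinalOnlyReporter(ProgressReporter):
        """Prints one line per trial, only when the run is finished."""

        def should_report(self, trials, done=False):
            return done

        def report(self, trials, done, *sys_info):
            for trial in trials:
                print(trial, trial.status, trial.last_result.get("mean_accuracy"))

    def train_fn(config):
        tune.report(metrics={"mean_accuracy": config["x"]})

    tune.run(
        train_fn,
        config={"x": tune.uniform(0, 1)},
        progress_reporter=FinalOnlyReporter(),
    )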
+ DEFAULT_COLUMNS = collections.OrderedDict( + { + MEAN_ACCURACY: "acc", + MEAN_LOSS: "loss", + TRAINING_ITERATION: "iter", + TIME_TOTAL_S: "total time (s)", + TIMESTEPS_TOTAL: "ts", + EPISODE_REWARD_MEAN: "reward", + } + ) + VALID_SUMMARY_TYPES = { + int, + float, + np.float32, + np.float64, + np.int32, + np.int64, + type(None), + } + + def __init__( + self, + *, + metric_columns: Optional[Union[List[str], Dict[str, str]]] = None, + parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None, + total_samples: Optional[int] = None, + max_progress_rows: int = 20, + max_error_rows: int = 20, + max_column_length: int = 20, + max_report_frequency: int = 5, + infer_limit: int = 3, + print_intermediate_tables: Optional[bool] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + sort_by_metric: bool = False, + ): + self._total_samples = total_samples + self._metrics_override = metric_columns is not None + self._inferred_metrics = {} + self._metric_columns = metric_columns or self.DEFAULT_COLUMNS.copy() + self._parameter_columns = parameter_columns or [] + self._max_progress_rows = max_progress_rows + self._max_error_rows = max_error_rows + self._max_column_length = max_column_length + self._infer_limit = infer_limit + + if print_intermediate_tables is None: + self._print_intermediate_tables = has_verbosity(Verbosity.V3_TRIAL_DETAILS) + else: + self._print_intermediate_tables = print_intermediate_tables + + self._max_report_freqency = max_report_frequency + self._last_report_time = 0 + + self._start_time = time.time() + + self._metric = metric + self._mode = mode + self._sort_by_metric = sort_by_metric + + def setup( + self, + start_time: Optional[float] = None, + total_samples: Optional[int] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + **kwargs, + ): + self.set_start_time(start_time) + self.set_total_samples(total_samples) + self.set_search_properties(metric=metric, mode=mode) + + def set_search_properties(self, metric: Optional[str], mode: Optional[str]): + if (self._metric and metric) or (self._mode and mode): + raise ValueError( + "You passed a `metric` or `mode` argument to `tune.TuneConfig()`, but " + "the reporter you are using was already instantiated with their " + "own `metric` and `mode` parameters. Either remove the arguments " + "from your reporter or from your call to `tune.TuneConfig()`" + ) + + if metric: + self._metric = metric + if mode: + self._mode = mode + + if self._metric is None and self._mode: + # If only a mode was passed, use anonymous metric + self._metric = DEFAULT_METRIC + + return True + + def set_total_samples(self, total_samples: int): + self._total_samples = total_samples + + def set_start_time(self, timestamp: Optional[float] = None): + if timestamp is not None: + self._start_time = time.time() + else: + self._start_time = timestamp + + def should_report(self, trials: List[Trial], done: bool = False): + if time.time() - self._last_report_time > self._max_report_freqency: + self._last_report_time = time.time() + return True + return done + + def add_metric_column(self, metric: str, representation: Optional[str] = None): + """Adds a metric to the existing columns. + + Args: + metric: Metric to add. This must be a metric being returned + in training step results. + representation: Representation to use in table. Defaults to + `metric`. 
+ """ + self._metrics_override = True + if metric in self._metric_columns: + raise ValueError("Column {} already exists.".format(metric)) + + if isinstance(self._metric_columns, MutableMapping): + representation = representation or metric + self._metric_columns[metric] = representation + else: + if representation is not None and representation != metric: + raise ValueError( + "`representation` cannot differ from `metric` " + "if this reporter was initialized with a list " + "of metric columns." + ) + self._metric_columns.append(metric) + + def add_parameter_column( + self, parameter: str, representation: Optional[str] = None + ): + """Adds a parameter to the existing columns. + + Args: + parameter: Parameter to add. This must be a parameter + specified in the configuration. + representation: Representation to use in table. Defaults to + `parameter`. + """ + if parameter in self._parameter_columns: + raise ValueError("Column {} already exists.".format(parameter)) + + if isinstance(self._parameter_columns, MutableMapping): + representation = representation or parameter + self._parameter_columns[parameter] = representation + else: + if representation is not None and representation != parameter: + raise ValueError( + "`representation` cannot differ from `parameter` " + "if this reporter was initialized with a list " + "of metric columns." + ) + self._parameter_columns.append(parameter) + + def _progress_str( + self, + trials: List[Trial], + done: bool, + *sys_info: Dict, + fmt: str = "psql", + delim: str = "\n", + ): + """Returns full progress string. + + This string contains a progress table and error table. The progress + table describes the progress of each trial. The error table lists + the error file, if any, corresponding to each trial. The latter only + exists if errors have occurred. + + Args: + trials: Trials to report on. + done: Whether this is the last progress report attempt. + fmt: Table format. See `tablefmt` in tabulate API. + delim: Delimiter between messages. + """ + if self._sort_by_metric and (self._metric is None or self._mode is None): + self._sort_by_metric = False + warnings.warn( + "Both 'metric' and 'mode' must be set to be able " + "to sort by metric. No sorting is performed." 
+ ) + if not self._metrics_override: + user_metrics = self._infer_user_metrics(trials, self._infer_limit) + self._metric_columns.update(user_metrics) + messages = [ + "== Status ==", + _time_passed_str(self._start_time, time.time()), + *sys_info, + ] + if done: + max_progress = None + max_error = None + else: + max_progress = self._max_progress_rows + max_error = self._max_error_rows + + current_best_trial, metric = self._current_best_trial(trials) + if current_best_trial: + messages.append( + _best_trial_str(current_best_trial, metric, self._parameter_columns) + ) + + if has_verbosity(Verbosity.V1_EXPERIMENT): + # Will filter the table in `trial_progress_str` + messages.append( + _trial_progress_str( + trials, + metric_columns=self._metric_columns, + parameter_columns=self._parameter_columns, + total_samples=self._total_samples, + force_table=self._print_intermediate_tables, + fmt=fmt, + max_rows=max_progress, + max_column_length=self._max_column_length, + done=done, + metric=self._metric, + mode=self._mode, + sort_by_metric=self._sort_by_metric, + ) + ) + messages.append(_trial_errors_str(trials, fmt=fmt, max_rows=max_error)) + + return delim.join(messages) + delim + + def _infer_user_metrics(self, trials: List[Trial], limit: int = 4): + """Try to infer the metrics to print out.""" + if len(self._inferred_metrics) >= limit: + return self._inferred_metrics + self._inferred_metrics = {} + for t in trials: + if not t.last_result: + continue + for metric, value in t.last_result.items(): + if metric not in self.DEFAULT_COLUMNS: + if metric not in AUTO_RESULT_KEYS: + if type(value) in self.VALID_SUMMARY_TYPES: + self._inferred_metrics[metric] = metric + + if len(self._inferred_metrics) >= limit: + return self._inferred_metrics + return self._inferred_metrics + + def _current_best_trial(self, trials: List[Trial]): + if not trials: + return None, None + + metric, mode = self._metric, self._mode + # If no metric has been set, see if exactly one has been reported + # and use that one. `mode` must still be set. + if not metric: + if len(self._inferred_metrics) == 1: + metric = list(self._inferred_metrics.keys())[0] + + if not metric or not mode: + return None, metric + + metric_op = 1.0 if mode == "max" else -1.0 + best_metric = float("-inf") + best_trial = None + for t in trials: + if not t.last_result: + continue + metric_value = unflattened_lookup(metric, t.last_result, default=None) + if pd.isnull(metric_value): + continue + if not best_trial or metric_value * metric_op > best_metric: + best_metric = metric_value * metric_op + best_trial = t + return best_trial, metric + + +@DeveloperAPI +class RemoteReporterMixin: + """Remote reporter abstract mixin class. + + Subclasses of this class will use a Ray Queue to display output + on the driver side when running Ray Client.""" + + @property + def output_queue(self) -> Queue: + return getattr(self, "_output_queue", None) + + @output_queue.setter + def output_queue(self, value: Queue): + self._output_queue = value + + def display(self, string: str) -> None: + """Display the progress string. + + Args: + string: String to display. + """ + raise NotImplementedError + + +@PublicAPI +class JupyterNotebookReporter(TuneReporterBase, RemoteReporterMixin): + """Jupyter notebook-friendly Reporter that can update display in-place. + + Args: + overwrite: Flag for overwriting the cell contents before initialization. + metric_columns: Names of metrics to + include in progress table. 
If this is a dict, the keys should + be metric names and the values should be the displayed names. + If this is a list, the metric name is used directly. + parameter_columns: Names of parameters to + include in progress table. If this is a dict, the keys should + be parameter names and the values should be the displayed names. + If this is a list, the parameter name is used directly. If empty, + defaults to all available parameters. + max_progress_rows: Maximum number of rows to print + in the progress table. The progress table describes the + progress of each trial. Defaults to 20. + max_error_rows: Maximum number of rows to print in the + error table. The error table lists the error file, if any, + corresponding to each trial. Defaults to 20. + max_column_length: Maximum column length (in characters). Column + headers and values longer than this will be abbreviated. + max_report_frequency: Maximum report frequency in seconds. + Defaults to 5s. + infer_limit: Maximum number of metrics to automatically infer + from tune results. + print_intermediate_tables: Print intermediate result + tables. If None (default), will be set to True for verbosity + levels above 3, otherwise False. If True, intermediate tables + will be printed with experiment progress. If False, tables + will only be printed at then end of the tuning run for verbosity + levels greater than 2. + metric: Metric used to determine best current trial. + mode: One of [min, max]. Determines whether objective is + minimizing or maximizing the metric attribute. + sort_by_metric: Sort terminated trials by metric in the + intermediate table. Defaults to False. + """ + + def __init__( + self, + *, + overwrite: bool = True, + metric_columns: Optional[Union[List[str], Dict[str, str]]] = None, + parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None, + total_samples: Optional[int] = None, + max_progress_rows: int = 20, + max_error_rows: int = 20, + max_column_length: int = 20, + max_report_frequency: int = 5, + infer_limit: int = 3, + print_intermediate_tables: Optional[bool] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + sort_by_metric: bool = False, + ): + super(JupyterNotebookReporter, self).__init__( + metric_columns=metric_columns, + parameter_columns=parameter_columns, + total_samples=total_samples, + max_progress_rows=max_progress_rows, + max_error_rows=max_error_rows, + max_column_length=max_column_length, + max_report_frequency=max_report_frequency, + infer_limit=infer_limit, + print_intermediate_tables=print_intermediate_tables, + metric=metric, + mode=mode, + sort_by_metric=sort_by_metric, + ) + + if not IS_NOTEBOOK: + warnings.warn( + "You are using the `JupyterNotebookReporter`, but not " + "IPython/Jupyter-compatible environment was detected. " + "If this leads to unformatted output (e.g. like " + "), consider passing " + "a `CLIReporter` as the `progress_reporter` argument " + "to `train.RunConfig()` instead." 
+ ) + + self._overwrite = overwrite + self._display_handle = None + self.display("") # initialize empty display to update later + + def report(self, trials: List[Trial], done: bool, *sys_info: Dict): + progress = self._progress_html(trials, done, *sys_info) + + if self.output_queue is not None: + # If an output queue is set, send string + self.output_queue.put(progress) + else: + # Else, output directly + self.display(progress) + + def display(self, string: str) -> None: + from IPython.display import HTML, clear_output, display + + if not self._display_handle: + if self._overwrite: + clear_output(wait=True) + self._display_handle = display(HTML(string), display_id=True) + else: + self._display_handle.update(HTML(string)) + + def _progress_html(self, trials: List[Trial], done: bool, *sys_info) -> str: + """Generate an HTML-formatted progress update. + + Args: + trials: List of trials for which progress should be + displayed + done: True if the trials are finished, False otherwise + *sys_info: System information to be displayed + + Returns: + Progress update to be rendered in a notebook, including HTML + tables and formatted error messages. Includes + - Duration of the tune job + - Memory consumption + - Trial progress table, with information about each experiment + """ + if not self._metrics_override: + user_metrics = self._infer_user_metrics(trials, self._infer_limit) + self._metric_columns.update(user_metrics) + + current_time, running_for = _get_time_str(self._start_time, time.time()) + used_gb, total_gb, memory_message = _get_memory_usage() + + status_table = tabulate( + [ + ("Current time:", current_time), + ("Running for:", running_for), + ("Memory:", f"{used_gb}/{total_gb} GiB"), + ], + tablefmt="html", + ) + trial_progress_data = _trial_progress_table( + trials=trials, + metric_columns=self._metric_columns, + parameter_columns=self._parameter_columns, + fmt="html", + max_rows=None if done else self._max_progress_rows, + metric=self._metric, + mode=self._mode, + sort_by_metric=self._sort_by_metric, + max_column_length=self._max_column_length, + ) + + trial_progress = trial_progress_data[0] + trial_progress_messages = trial_progress_data[1:] + trial_errors = _trial_errors_str( + trials, fmt="html", max_rows=None if done else self._max_error_rows + ) + + if any([memory_message, trial_progress_messages, trial_errors]): + msg = Template("tune_status_messages.html.j2").render( + memory_message=memory_message, + trial_progress_messages=trial_progress_messages, + trial_errors=trial_errors, + ) + else: + msg = None + + return Template("tune_status.html.j2").render( + status_table=status_table, + sys_info_message=_generate_sys_info_str(*sys_info), + trial_progress=trial_progress, + messages=msg, + ) + + +@PublicAPI +class CLIReporter(TuneReporterBase): + """Command-line reporter + + Args: + metric_columns: Names of metrics to + include in progress table. If this is a dict, the keys should + be metric names and the values should be the displayed names. + If this is a list, the metric name is used directly. + parameter_columns: Names of parameters to + include in progress table. If this is a dict, the keys should + be parameter names and the values should be the displayed names. + If this is a list, the parameter name is used directly. If empty, + defaults to all available parameters. + max_progress_rows: Maximum number of rows to print + in the progress table. The progress table describes the + progress of each trial. Defaults to 20. 
+ max_error_rows: Maximum number of rows to print in the + error table. The error table lists the error file, if any, + corresponding to each trial. Defaults to 20. + max_column_length: Maximum column length (in characters). Column + headers and values longer than this will be abbreviated. + max_report_frequency: Maximum report frequency in seconds. + Defaults to 5s. + infer_limit: Maximum number of metrics to automatically infer + from tune results. + print_intermediate_tables: Print intermediate result + tables. If None (default), will be set to True for verbosity + levels above 3, otherwise False. If True, intermediate tables + will be printed with experiment progress. If False, tables + will only be printed at then end of the tuning run for verbosity + levels greater than 2. + metric: Metric used to determine best current trial. + mode: One of [min, max]. Determines whether objective is + minimizing or maximizing the metric attribute. + sort_by_metric: Sort terminated trials by metric in the + intermediate table. Defaults to False. + """ + + def __init__( + self, + *, + metric_columns: Optional[Union[List[str], Dict[str, str]]] = None, + parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None, + total_samples: Optional[int] = None, + max_progress_rows: int = 20, + max_error_rows: int = 20, + max_column_length: int = 20, + max_report_frequency: int = 5, + infer_limit: int = 3, + print_intermediate_tables: Optional[bool] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + sort_by_metric: bool = False, + ): + super(CLIReporter, self).__init__( + metric_columns=metric_columns, + parameter_columns=parameter_columns, + total_samples=total_samples, + max_progress_rows=max_progress_rows, + max_error_rows=max_error_rows, + max_column_length=max_column_length, + max_report_frequency=max_report_frequency, + infer_limit=infer_limit, + print_intermediate_tables=print_intermediate_tables, + metric=metric, + mode=mode, + sort_by_metric=sort_by_metric, + ) + + def _print(self, msg: str): + safe_print(msg) + + def report(self, trials: List[Trial], done: bool, *sys_info: Dict): + self._print(self._progress_str(trials, done, *sys_info)) + + +def _get_memory_usage() -> Tuple[float, float, Optional[str]]: + """Get the current memory consumption. + + Returns: + Memory used, memory available, and optionally a warning + message to be shown to the user when memory consumption is higher + than 90% or if `psutil` is not installed + """ + try: + import ray # noqa F401 + + import psutil + + total_gb = psutil.virtual_memory().total / (1024**3) + used_gb = total_gb - psutil.virtual_memory().available / (1024**3) + if used_gb > total_gb * 0.9: + message = ( + ": ***LOW MEMORY*** less than 10% of the memory on " + "this node is available for use. This can cause " + "unexpected crashes. Consider " + "reducing the memory used by your application " + "or reducing the Ray object store size by setting " + "`object_store_memory` when calling `ray.init`." + ) + else: + message = None + + return round(used_gb, 1), round(total_gb, 1), message + except ImportError: + return ( + np.nan, + np.nan, + "Unknown memory usage. Please run `pip install psutil` to resolve", + ) + + +def _get_time_str(start_time: float, current_time: float) -> Tuple[str, str]: + """Get strings representing the current and elapsed time. 
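Tying the column options together: a sketch of configuring the CLI reporter described above. Constructing the reporter needs no running cluster; note that metric/mode should be given either here or to Tune, not both, since set_search_properties above raises otherwise.

    from ray.tune import CLIReporter

    reporter = CLIReporter(
        # result-dict key -> column header shown in the progress table
        metric_columns={"mean_accuracy": "acc", "training_iteration": "iter"},
        # config key -> column header
        parameter_columns={"lr": "lr"},
        max_progress_rows=10,
        max_report_frequency=30,  # print at most every 30 seconds
        metric="mean_accuracy",
        mode="max",
        sort_by_metric=True,
    )

    # Columns can also be added after construction:
    reporter.add_metric_column("time_total_s", "total time (s)")

    # Attach via e.g. `tune.run(..., progress_reporter=reporter)`.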
+ + Args: + start_time: POSIX timestamp of the start of the tune run + current_time: POSIX timestamp giving the current time + + Returns: + Current time and elapsed time for the current run + """ + current_time_dt = datetime.datetime.fromtimestamp(current_time) + start_time_dt = datetime.datetime.fromtimestamp(start_time) + delta: datetime.timedelta = current_time_dt - start_time_dt + + rest = delta.total_seconds() + days = rest // (60 * 60 * 24) + + rest -= days * (60 * 60 * 24) + hours = rest // (60 * 60) + + rest -= hours * (60 * 60) + minutes = rest // 60 + + seconds = rest - minutes * 60 + + if days > 0: + running_for_str = f"{days:.0f} days, " + else: + running_for_str = "" + + running_for_str += f"{hours:02.0f}:{minutes:02.0f}:{seconds:05.2f}" + + return f"{current_time_dt:%Y-%m-%d %H:%M:%S}", running_for_str + + +def _time_passed_str(start_time: float, current_time: float) -> str: + """Generate a message describing the current and elapsed time in the run. + + Args: + start_time: POSIX timestamp of the start of the tune run + current_time: POSIX timestamp giving the current time + + Returns: + Message with the current and elapsed time for the current tune run, + formatted to be displayed to the user + """ + current_time_str, running_for_str = _get_time_str(start_time, current_time) + return f"Current time: {current_time_str} " f"(running for {running_for_str})" + + +def _get_trials_by_state(trials: List[Trial]): + trials_by_state = collections.defaultdict(list) + for t in trials: + trials_by_state[t.status].append(t) + return trials_by_state + + +def _trial_progress_str( + trials: List[Trial], + metric_columns: Union[List[str], Dict[str, str]], + parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None, + total_samples: int = 0, + force_table: bool = False, + fmt: str = "psql", + max_rows: Optional[int] = None, + max_column_length: int = 20, + done: bool = False, + metric: Optional[str] = None, + mode: Optional[str] = None, + sort_by_metric: bool = False, +): + """Returns a human readable message for printing to the console. + + This contains a table where each row represents a trial, its parameters + and the current values of its metrics. + + Args: + trials: List of trials to get progress string for. + metric_columns: Names of metrics to include. + If this is a dict, the keys are metric names and the values are + the names to use in the message. If this is a list, the metric + name is used in the message directly. + parameter_columns: Names of parameters to + include. If this is a dict, the keys are parameter names and the + values are the names to use in the message. If this is a list, + the parameter name is used in the message directly. If this is + empty, all parameters are used in the message. + total_samples: Total number of trials that will be generated. + force_table: Force printing a table. If False, a table will + be printed only at the end of the training for verbosity levels + above `Verbosity.V2_TRIAL_NORM`. + fmt: Output format (see tablefmt in tabulate API). + max_rows: Maximum number of rows in the trial table. Defaults to + unlimited. + max_column_length: Maximum column length (in characters). + done: True indicates that the tuning run finished. + metric: Metric used to sort trials. + mode: One of [min, max]. Determines whether objective is + minimizing or maximizing the metric attribute. + sort_by_metric: Sort terminated trials by metric in the + intermediate table. Defaults to False. + """ + messages = [] + delim = "
" if fmt == "html" else "\n" + if len(trials) < 1: + return delim.join(messages) + + num_trials = len(trials) + trials_by_state = _get_trials_by_state(trials) + + for local_dir in sorted({t.local_experiment_path for t in trials}): + messages.append("Result logdir: {}".format(local_dir)) + + num_trials_strs = [ + "{} {}".format(len(trials_by_state[state]), state) + for state in sorted(trials_by_state) + ] + + if total_samples and total_samples >= sys.maxsize: + total_samples = "infinite" + + messages.append( + "Number of trials: {}{} ({})".format( + num_trials, + f"/{total_samples}" if total_samples else "", + ", ".join(num_trials_strs), + ) + ) + + if force_table or (has_verbosity(Verbosity.V2_TRIAL_NORM) and done): + messages += _trial_progress_table( + trials=trials, + metric_columns=metric_columns, + parameter_columns=parameter_columns, + fmt=fmt, + max_rows=max_rows, + metric=metric, + mode=mode, + sort_by_metric=sort_by_metric, + max_column_length=max_column_length, + ) + + return delim.join(messages) + + +def _max_len( + value: Any, max_len: int = 20, add_addr: bool = False, wrap: bool = False +) -> Any: + """Abbreviate a string representation of an object to `max_len` characters. + + For numbers, booleans and None, the original value will be returned for + correct rendering in the table formatting tool. + + Args: + value: Object to be represented as a string. + max_len: Maximum return string length. + add_addr: If True, will add part of the object address to the end of the + string, e.g. to identify different instances of the same class. If + False, three dots (``...``) will be used instead. + """ + if value is None or isinstance(value, (int, float, numbers.Number, bool)): + return value + + string = str(value) + if len(string) <= max_len: + return string + + if wrap: + # Maximum two rows. + # Todo: Make this configurable in the refactor + if len(value) > max_len * 2: + value = "..." + string[(3 - (max_len * 2)) :] + + wrapped = textwrap.wrap(value, width=max_len) + return "\n".join(wrapped) + + if add_addr and not isinstance(value, (int, float, bool)): + result = f"{string[: (max_len - 5)]}_{hex(id(value))[-4:]}" + return result + + result = "..." + string[(3 - max_len) :] + return result + + +def _get_progress_table_data( + trials: List[Trial], + metric_columns: Union[List[str], Dict[str, str]], + parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None, + max_rows: Optional[int] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + sort_by_metric: bool = False, + max_column_length: int = 20, +) -> Tuple[List, List[str], Tuple[bool, str]]: + """Generate a table showing the current progress of tuning trials. + + Args: + trials: List of trials for which progress is to be shown. + metric_columns: Metrics to be displayed in the table. + parameter_columns: List of parameters to be included in the data + max_rows: Maximum number of rows to show. 
If there's overflow, a + message will be shown to the user indicating that some rows + are not displayed + metric: Metric which is being tuned + mode: Sort the table in descending order if mode is "max"; + ascending otherwise + sort_by_metric: If true, the table will be sorted by the metric + max_column_length: Max number of characters in each column + + Returns: + - Trial data + - List of column names + - Overflow tuple: + - boolean indicating whether the table has rows which are hidden + - string with info about the overflowing rows + """ + num_trials = len(trials) + trials_by_state = _get_trials_by_state(trials) + + # Sort terminated trials by metric and mode, descending if mode is "max" + if sort_by_metric: + trials_by_state[Trial.TERMINATED] = sorted( + trials_by_state[Trial.TERMINATED], + reverse=(mode == "max"), + key=lambda t: unflattened_lookup(metric, t.last_result, default=None), + ) + + state_tbl_order = [ + Trial.RUNNING, + Trial.PAUSED, + Trial.PENDING, + Trial.TERMINATED, + Trial.ERROR, + ] + max_rows = max_rows or float("inf") + if num_trials > max_rows: + # TODO(ujvl): suggestion for users to view more rows. + trials_by_state_trunc = _fair_filter_trials( + trials_by_state, max_rows, sort_by_metric + ) + trials = [] + overflow_strs = [] + for state in state_tbl_order: + if state not in trials_by_state: + continue + trials += trials_by_state_trunc[state] + num = len(trials_by_state[state]) - len(trials_by_state_trunc[state]) + if num > 0: + overflow_strs.append("{} {}".format(num, state)) + # Build overflow string. + overflow = num_trials - max_rows + overflow_str = ", ".join(overflow_strs) + else: + overflow = False + overflow_str = "" + trials = [] + for state in state_tbl_order: + if state not in trials_by_state: + continue + trials += trials_by_state[state] + + # Pre-process trials to figure out what columns to show. + if isinstance(metric_columns, Mapping): + metric_keys = list(metric_columns.keys()) + else: + metric_keys = metric_columns + + metric_keys = [ + k + for k in metric_keys + if any( + unflattened_lookup(k, t.last_result, default=None) is not None + for t in trials + ) + ] + + if not parameter_columns: + parameter_keys = sorted(set().union(*[t.evaluated_params for t in trials])) + elif isinstance(parameter_columns, Mapping): + parameter_keys = list(parameter_columns.keys()) + else: + parameter_keys = parameter_columns + + # Build trial rows. 
+ trial_table = [ + _get_trial_info( + trial, parameter_keys, metric_keys, max_column_length=max_column_length + ) + for trial in trials + ] + # Format column headings + if isinstance(metric_columns, Mapping): + formatted_metric_columns = [ + _max_len( + metric_columns[k], max_len=max_column_length, add_addr=False, wrap=True + ) + for k in metric_keys + ] + else: + formatted_metric_columns = [ + _max_len(k, max_len=max_column_length, add_addr=False, wrap=True) + for k in metric_keys + ] + if isinstance(parameter_columns, Mapping): + formatted_parameter_columns = [ + _max_len( + parameter_columns[k], + max_len=max_column_length, + add_addr=False, + wrap=True, + ) + for k in parameter_keys + ] + else: + formatted_parameter_columns = [ + _max_len(k, max_len=max_column_length, add_addr=False, wrap=True) + for k in parameter_keys + ] + columns = ( + ["Trial name", "status", "loc"] + + formatted_parameter_columns + + formatted_metric_columns + ) + + return trial_table, columns, (overflow, overflow_str) + + +def _trial_progress_table( + trials: List[Trial], + metric_columns: Union[List[str], Dict[str, str]], + parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None, + fmt: str = "psql", + max_rows: Optional[int] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + sort_by_metric: bool = False, + max_column_length: int = 20, +) -> List[str]: + """Generate a list of trial progress table messages. + + Args: + trials: List of trials for which progress is to be shown. + metric_columns: Metrics to be displayed in the table. + parameter_columns: List of parameters to be included in the data + fmt: Format of the table; passed to tabulate as the fmtstr argument + max_rows: Maximum number of rows to show. If there's overflow, a + message will be shown to the user indicating that some rows + are not displayed + metric: Metric which is being tuned + mode: Sort the table in descenting order if mode is "max"; + ascending otherwise + sort_by_metric: If true, the table will be sorted by the metric + max_column_length: Max number of characters in each column + + Returns: + Messages to be shown to the user containing progress tables + """ + data, columns, (overflow, overflow_str) = _get_progress_table_data( + trials, + metric_columns, + parameter_columns, + max_rows, + metric, + mode, + sort_by_metric, + max_column_length, + ) + messages = [tabulate(data, headers=columns, tablefmt=fmt, showindex=False)] + if overflow: + messages.append(f"... {overflow} more trials not shown ({overflow_str})") + return messages + + +def _generate_sys_info_str(*sys_info) -> str: + """Format system info into a string. + *sys_info: System info strings to be included. + + Returns: + Formatted string containing system information. + """ + if sys_info: + return "
".join(sys_info).replace("\n", "
") + return "" + + +def _trial_errors_str( + trials: List[Trial], fmt: str = "psql", max_rows: Optional[int] = None +): + """Returns a readable message regarding trial errors. + + Args: + trials: List of trials to get progress string for. + fmt: Output format (see tablefmt in tabulate API). + max_rows: Maximum number of rows in the error table. Defaults to + unlimited. + """ + messages = [] + failed = [t for t in trials if t.error_file] + num_failed = len(failed) + if num_failed > 0: + messages.append("Number of errored trials: {}".format(num_failed)) + if num_failed > (max_rows or float("inf")): + messages.append( + "Table truncated to {} rows ({} overflow)".format( + max_rows, num_failed - max_rows + ) + ) + + fail_header = ["Trial name", "# failures", "error file"] + fail_table_data = [ + [ + str(trial), + str(trial.run_metadata.num_failures) + + ("" if trial.status == Trial.ERROR else "*"), + trial.error_file, + ] + for trial in failed[:max_rows] + ] + messages.append( + tabulate( + fail_table_data, + headers=fail_header, + tablefmt=fmt, + showindex=False, + colalign=("left", "right", "left"), + ) + ) + if any(trial.status == Trial.TERMINATED for trial in failed[:max_rows]): + messages.append("* The trial terminated successfully after retrying.") + + delim = "
" if fmt == "html" else "\n" + return delim.join(messages) + + +def _best_trial_str( + trial: Trial, + metric: str, + parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None, +): + """Returns a readable message stating the current best trial.""" + val = unflattened_lookup(metric, trial.last_result, default=None) + config = trial.last_result.get("config", {}) + parameter_columns = parameter_columns or list(config.keys()) + if isinstance(parameter_columns, Mapping): + parameter_columns = parameter_columns.keys() + params = {p: unflattened_lookup(p, config) for p in parameter_columns} + return ( + f"Current best trial: {trial.trial_id} with {metric}={val} and " + f"parameters={params}" + ) + + +def _fair_filter_trials( + trials_by_state: Dict[str, List[Trial]], + max_trials: int, + sort_by_metric: bool = False, +): + """Filters trials such that each state is represented fairly. + + The oldest trials are truncated if necessary. + + Args: + trials_by_state: Maximum number of trials to return. + Returns: + Dict mapping state to List of fairly represented trials. + """ + num_trials_by_state = collections.defaultdict(int) + no_change = False + # Determine number of trials to keep per state. + while max_trials > 0 and not no_change: + no_change = True + for state in sorted(trials_by_state): + if num_trials_by_state[state] < len(trials_by_state[state]): + no_change = False + max_trials -= 1 + num_trials_by_state[state] += 1 + # Sort by start time, descending if the trails is not sorted by metric. + sorted_trials_by_state = dict() + for state in sorted(trials_by_state): + if state == Trial.TERMINATED and sort_by_metric: + sorted_trials_by_state[state] = trials_by_state[state] + else: + sorted_trials_by_state[state] = sorted( + trials_by_state[state], reverse=False, key=lambda t: t.trial_id + ) + # Truncate oldest trials. + filtered_trials = { + state: sorted_trials_by_state[state][: num_trials_by_state[state]] + for state in sorted(trials_by_state) + } + return filtered_trials + + +def _get_trial_location(trial: Trial, result: dict) -> _Location: + # we get the location from the result, as the one in trial will be + # reset when trial terminates + node_ip, pid = result.get(NODE_IP, None), result.get(PID, None) + if node_ip and pid: + location = _Location(node_ip, pid) + else: + # fallback to trial location if there hasn't been a report yet + location = trial.temporary_state.location + return location + + +def _get_trial_info( + trial: Trial, parameters: List[str], metrics: List[str], max_column_length: int = 20 +): + """Returns the following information about a trial: + + name | status | loc | params... | metrics... + + Args: + trial: Trial to get information for. + parameters: Names of trial parameters to include. + metrics: Names of metrics to include. + max_column_length: Maximum column length (in characters). + """ + result = trial.last_result + config = trial.config + location = _get_trial_location(trial, result) + trial_info = [str(trial), trial.status, str(location)] + trial_info += [ + _max_len( + unflattened_lookup(param, config, default=None), + max_len=max_column_length, + add_addr=True, + ) + for param in parameters + ] + trial_info += [ + _max_len( + unflattened_lookup(metric, result, default=None), + max_len=max_column_length, + add_addr=True, + ) + for metric in metrics + ] + return trial_info + + +@DeveloperAPI +class TrialProgressCallback(Callback): + """Reports (prints) intermediate trial progress. + + This callback is automatically added to the callback stack. 
When a + result is obtained, this callback will print the results according to + the specified verbosity level. + + For ``Verbosity.V3_TRIAL_DETAILS``, a full result list is printed. + + For ``Verbosity.V2_TRIAL_NORM``, only one line is printed per received + result. + + All other verbosity levels do not print intermediate trial progress. + + Result printing is throttled on a per-trial basis. Per default, results are + printed only once every 30 seconds. Results are always printed when a trial + finished or errored. + + """ + + def __init__( + self, metric: Optional[str] = None, progress_metrics: Optional[List[str]] = None + ): + self._last_print = collections.defaultdict(float) + self._last_print_iteration = collections.defaultdict(int) + self._completed_trials = set() + self._last_result_str = {} + self._metric = metric + self._progress_metrics = set(progress_metrics or []) + + # Only use progress metrics if at least two metrics are in there + if self._metric and self._progress_metrics: + self._progress_metrics.add(self._metric) + self._last_result = {} + self._display_handle = None + + def _print(self, msg: str): + safe_print(msg) + + def on_trial_result( + self, + iteration: int, + trials: List["Trial"], + trial: "Trial", + result: Dict, + **info, + ): + self.log_result(trial, result, error=False) + + def on_trial_error( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + self.log_result(trial, trial.last_result, error=True) + + def on_trial_complete( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + # Only log when we never logged that a trial was completed + if trial not in self._completed_trials: + self._completed_trials.add(trial) + + print_result_str = self._print_result(trial.last_result) + last_result_str = self._last_result_str.get(trial, "") + # If this is a new result, print full result string + if print_result_str != last_result_str: + self.log_result(trial, trial.last_result, error=False) + else: + self._print(f"Trial {trial} completed. Last result: {print_result_str}") + + def log_result(self, trial: "Trial", result: Dict, error: bool = False): + done = result.get("done", False) is True + last_print = self._last_print[trial] + should_print = done or error or time.time() - last_print > DEBUG_PRINT_INTERVAL + + if done and trial not in self._completed_trials: + self._completed_trials.add(trial) + + if should_print: + if IS_NOTEBOOK: + self.display_result(trial, result, error, done) + else: + self.print_result(trial, result, error, done) + + self._last_print[trial] = time.time() + if TRAINING_ITERATION in result: + self._last_print_iteration[trial] = result[TRAINING_ITERATION] + + def print_result(self, trial: Trial, result: Dict, error: bool, done: bool): + """Print the most recent results for the given trial to stdout. 
+ + Args: + trial: Trial for which results are to be printed + result: Result to be printed + error: True if an error has occurred, False otherwise + done: True if the trial is finished, False otherwise + """ + last_print_iteration = self._last_print_iteration[trial] + + if has_verbosity(Verbosity.V3_TRIAL_DETAILS): + if result.get(TRAINING_ITERATION) != last_print_iteration: + self._print(f"Result for {trial}:") + self._print(" {}".format(pretty_print(result).replace("\n", "\n "))) + if done: + self._print(f"Trial {trial} completed.") + + elif has_verbosity(Verbosity.V2_TRIAL_NORM): + metric_name = self._metric or "_metric" + metric_value = result.get(metric_name, -99.0) + error_file = Path(trial.local_path, EXPR_ERROR_FILE).as_posix() + + info = "" + if done: + info = " This trial completed." + + print_result_str = self._print_result(result) + + self._last_result_str[trial] = print_result_str + + if error: + message = ( + f"The trial {trial} errored with " + f"parameters={trial.config}. " + f"Error file: {error_file}" + ) + elif self._metric: + message = ( + f"Trial {trial} reported " + f"{metric_name}={metric_value:.2f} " + f"with parameters={trial.config}.{info}" + ) + else: + message = ( + f"Trial {trial} reported " + f"{print_result_str} " + f"with parameters={trial.config}.{info}" + ) + + self._print(message) + + def generate_trial_table( + self, trials: Dict[Trial, Dict], columns: List[str] + ) -> str: + """Generate an HTML table of trial progress info. + + Trials (rows) are sorted by name; progress stats (columns) are sorted + as well. + + Args: + trials: Trials and their associated latest results + columns: Columns to show in the table; must be a list of valid + keys for each Trial result + + Returns: + HTML template containing a rendered table of progress info + """ + data = [] + columns = sorted(columns) + + sorted_trials = collections.OrderedDict( + sorted(self._last_result.items(), key=lambda item: str(item[0])) + ) + for trial, result in sorted_trials.items(): + data.append([str(trial)] + [result.get(col, "") for col in columns]) + + return Template("trial_progress.html.j2").render( + table=tabulate( + data, tablefmt="html", headers=["Trial name"] + columns, showindex=False + ) + ) + + def display_result(self, trial: Trial, result: Dict, error: bool, done: bool): + """Display a formatted HTML table of trial progress results. + + Trial progress is only shown if verbosity is set to level 2 or 3. 
+ + Args: + trial: Trial for which results are to be printed + result: Result to be printed + error: True if an error has occurred, False otherwise + done: True if the trial is finished, False otherwise + """ + from IPython.display import HTML, display + + self._last_result[trial] = result + if has_verbosity(Verbosity.V3_TRIAL_DETAILS): + ignored_keys = { + "config", + "hist_stats", + } + + elif has_verbosity(Verbosity.V2_TRIAL_NORM): + ignored_keys = { + "config", + "hist_stats", + "trial_id", + "experiment_tag", + "done", + } | set(AUTO_RESULT_KEYS) + else: + return + + table = self.generate_trial_table( + self._last_result, set(result.keys()) - ignored_keys + ) + if not self._display_handle: + self._display_handle = display(HTML(table), display_id=True) + else: + self._display_handle.update(HTML(table)) + + def _print_result(self, result: Dict): + if self._progress_metrics: + # If progress metrics are given, only report these + flat_result = flatten_dict(result) + + print_result = {} + for metric in self._progress_metrics: + print_result[metric] = flat_result.get(metric) + + else: + # Else, skip auto populated results + print_result = result.copy() + + for skip_result in SKIP_RESULTS_IN_REPORT: + print_result.pop(skip_result, None) + + for auto_result in AUTO_RESULT_KEYS: + print_result.pop(auto_result, None) + + print_result_str = ",".join( + [f"{k}={v}" for k, v in print_result.items() if v is not None] + ) + return print_result_str + + +def _detect_reporter(_trainer_api: bool = False, **kwargs) -> TuneReporterBase: + """Detect progress reporter class. + + Will return a :class:`JupyterNotebookReporter` if a IPython/Jupyter-like + session was detected, and a :class:`CLIReporter` otherwise. + + Keyword arguments are passed on to the reporter class. + """ + if IS_NOTEBOOK and not _trainer_api: + kwargs.setdefault("overwrite", not has_verbosity(Verbosity.V2_TRIAL_NORM)) + progress_reporter = JupyterNotebookReporter(**kwargs) + else: + progress_reporter = CLIReporter(**kwargs) + return progress_reporter + + +def _detect_progress_metrics( + trainable: Optional[Union["Trainable", Callable]] +) -> Optional[Collection[str]]: + """Detect progress metrics to report.""" + if not trainable: + return None + + return getattr(trainable, "_progress_metrics", None) + + +def _prepare_progress_reporter_for_ray_client( + progress_reporter: ProgressReporter, + verbosity: Union[int, Verbosity], + string_queue: Optional[Queue] = None, +) -> Tuple[ProgressReporter, Queue]: + """Prepares progress reported for Ray Client by setting the string queue. + + The string queue will be created if it's None.""" + set_verbosity(verbosity) + progress_reporter = progress_reporter or _detect_reporter() + + # JupyterNotebooks don't work with remote tune runs out of the box + # (e.g. via Ray client) as they don't have access to the main + # process stdout. So we introduce a queue here that accepts + # strings, which will then be displayed on the driver side. 
+ if isinstance(progress_reporter, RemoteReporterMixin): + if string_queue is None: + string_queue = Queue( + actor_options={"num_cpus": 0, **_force_on_current_node(None)} + ) + progress_reporter.output_queue = string_queue + + return progress_reporter, string_queue + + +def _stream_client_output( + remote_future: ray.ObjectRef, + progress_reporter: ProgressReporter, + string_queue: Queue, +) -> Any: + """ + Stream items from string queue to progress_reporter until remote_future resolves + """ + if string_queue is None: + return + + def get_next_queue_item(): + try: + return string_queue.get(block=False) + except Empty: + return None + + def _handle_string_queue(): + string_item = get_next_queue_item() + while string_item is not None: + # This happens on the driver side + progress_reporter.display(string_item) + string_item = get_next_queue_item() + + # ray.wait(...)[1] returns futures that are not ready, yet + while ray.wait([remote_future], timeout=0.2)[1]: + # Check if we have items to execute + _handle_string_queue() + + # Handle queue one last time + _handle_string_queue() diff --git a/.venv/lib/python3.11/site-packages/ray/tune/registry.py b/.venv/lib/python3.11/site-packages/ray/tune/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..17eb93c41abf915994337a35361dcdd299bd96e7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/registry.py @@ -0,0 +1,314 @@ +import atexit +import logging +from functools import partial +from types import FunctionType +from typing import Callable, Optional, Type, Union + +import ray +import ray.cloudpickle as pickle +from ray.experimental.internal_kv import ( + _internal_kv_del, + _internal_kv_get, + _internal_kv_initialized, + _internal_kv_put, +) +from ray.tune.error import TuneError +from ray.util.annotations import DeveloperAPI + +TRAINABLE_CLASS = "trainable_class" +ENV_CREATOR = "env_creator" +RLLIB_MODEL = "rllib_model" +RLLIB_PREPROCESSOR = "rllib_preprocessor" +RLLIB_ACTION_DIST = "rllib_action_dist" +RLLIB_INPUT = "rllib_input" +RLLIB_CONNECTOR = "rllib_connector" +TEST = "__test__" +KNOWN_CATEGORIES = [ + TRAINABLE_CLASS, + ENV_CREATOR, + RLLIB_MODEL, + RLLIB_PREPROCESSOR, + RLLIB_ACTION_DIST, + RLLIB_INPUT, + RLLIB_CONNECTOR, + TEST, +] + +logger = logging.getLogger(__name__) + + +def _has_trainable(trainable_name): + return _global_registry.contains(TRAINABLE_CLASS, trainable_name) + + +@DeveloperAPI +def get_trainable_cls(trainable_name): + validate_trainable(trainable_name) + return _global_registry.get(TRAINABLE_CLASS, trainable_name) + + +@DeveloperAPI +def validate_trainable(trainable_name: str): + if not _has_trainable(trainable_name) and not _has_rllib_trainable(trainable_name): + raise TuneError(f"Unknown trainable: {trainable_name}") + + +def _has_rllib_trainable(trainable_name: str) -> bool: + try: + # Make sure everything rllib-related is registered. + from ray.rllib import _register_all + except (ImportError, ModuleNotFoundError): + return False + + _register_all() + return _has_trainable(trainable_name) + + +@DeveloperAPI +def is_function_trainable(trainable: Union[str, Callable, Type]) -> bool: + """Check if a given trainable is a function trainable. 
+ Either the trainable has been wrapped as a FunctionTrainable class already, + or it's still a FunctionType/partial/callable.""" + from ray.tune.trainable import FunctionTrainable + + if isinstance(trainable, str): + trainable = get_trainable_cls(trainable) + + is_wrapped_func = isinstance(trainable, type) and issubclass( + trainable, FunctionTrainable + ) + return is_wrapped_func or ( + not isinstance(trainable, type) + and ( + isinstance(trainable, FunctionType) + or isinstance(trainable, partial) + or callable(trainable) + ) + ) + + +@DeveloperAPI +def register_trainable(name: str, trainable: Union[Callable, Type], warn: bool = True): + """Register a trainable function or class. + + This enables a class or function to be accessed on every Ray process + in the cluster. + + Args: + name: Name to register. + trainable: Function or tune.Trainable class. Functions must + take (config, status_reporter) as arguments and will be + automatically converted into a class during registration. + """ + + from ray.tune.trainable import Trainable, wrap_function + + if isinstance(trainable, type): + logger.debug("Detected class for trainable.") + elif isinstance(trainable, FunctionType) or isinstance(trainable, partial): + logger.debug("Detected function for trainable.") + trainable = wrap_function(trainable) + elif callable(trainable): + logger.info("Detected unknown callable for trainable. Converting to class.") + trainable = wrap_function(trainable) + + if not issubclass(trainable, Trainable): + raise TypeError("Second argument must be convertable to Trainable", trainable) + _global_registry.register(TRAINABLE_CLASS, name, trainable) + + +def _unregister_trainables(): + _global_registry.unregister_all(TRAINABLE_CLASS) + + +@DeveloperAPI +def register_env(name: str, env_creator: Callable): + """Register a custom environment for use with RLlib. + + This enables the environment to be accessed on every Ray process + in the cluster. + + Args: + name: Name to register. + env_creator: Callable that creates an env. + """ + + if not callable(env_creator): + raise TypeError("Second argument must be callable.", env_creator) + _global_registry.register(ENV_CREATOR, name, env_creator) + + +def _unregister_envs(): + _global_registry.unregister_all(ENV_CREATOR) + + +@DeveloperAPI +def register_input(name: str, input_creator: Callable): + """Register a custom input api for RLlib. + + Args: + name: Name to register. + input_creator: Callable that creates an + input reader. + """ + if not callable(input_creator): + raise TypeError("Second argument must be callable.", input_creator) + _global_registry.register(RLLIB_INPUT, name, input_creator) + + +def _unregister_inputs(): + _global_registry.unregister_all(RLLIB_INPUT) + + +@DeveloperAPI +def registry_contains_input(name: str) -> bool: + return _global_registry.contains(RLLIB_INPUT, name) + + +@DeveloperAPI +def registry_get_input(name: str) -> Callable: + return _global_registry.get(RLLIB_INPUT, name) + + +def _unregister_all(): + _unregister_inputs() + _unregister_envs() + _unregister_trainables() + + +def _check_serializability(key, value): + _global_registry.register(TEST, key, value) + + +def _make_key(prefix: str, category: str, key: str): + """Generate a binary key for the given category and key. + + Args: + prefix: Prefix + category: The category of the item + key: The unique identifier for the item + + Returns: + The key to use for storing a the value. 
+ """ + return ( + b"TuneRegistry:" + + prefix.encode("ascii") + + b":" + + category.encode("ascii") + + b"/" + + key.encode("ascii") + ) + + +class _Registry: + def __init__(self, prefix: Optional[str] = None): + """If no prefix is given, use runtime context job ID.""" + self._to_flush = {} + self._prefix = prefix + self._registered = set() + self._atexit_handler_registered = False + + @property + def prefix(self): + if not self._prefix: + self._prefix = ray.get_runtime_context().get_job_id() + return self._prefix + + def _register_atexit(self): + if self._atexit_handler_registered: + # Already registered + return + + if ray._private.worker.global_worker.mode != ray.SCRIPT_MODE: + # Only cleanup on the driver + return + + atexit.register(_unregister_all) + self._atexit_handler_registered = True + + def register(self, category, key, value): + """Registers the value with the global registry. + + Raises: + PicklingError if unable to pickle to provided file. + """ + if category not in KNOWN_CATEGORIES: + from ray.tune import TuneError + + raise TuneError( + "Unknown category {} not among {}".format(category, KNOWN_CATEGORIES) + ) + self._to_flush[(category, key)] = pickle.dumps_debug(value) + if _internal_kv_initialized(): + self.flush_values() + + def unregister(self, category, key): + if _internal_kv_initialized(): + _internal_kv_del(_make_key(self.prefix, category, key)) + else: + self._to_flush.pop((category, key), None) + + def unregister_all(self, category: Optional[str] = None): + remaining = set() + for cat, key in self._registered: + if category and category == cat: + self.unregister(cat, key) + else: + remaining.add((cat, key)) + self._registered = remaining + + def contains(self, category, key): + if _internal_kv_initialized(): + value = _internal_kv_get(_make_key(self.prefix, category, key)) + return value is not None + else: + return (category, key) in self._to_flush + + def get(self, category, key): + if _internal_kv_initialized(): + value = _internal_kv_get(_make_key(self.prefix, category, key)) + if value is None: + raise ValueError( + "Registry value for {}/{} doesn't exist.".format(category, key) + ) + return pickle.loads(value) + else: + return pickle.loads(self._to_flush[(category, key)]) + + def flush_values(self): + self._register_atexit() + for (category, key), value in self._to_flush.items(): + _internal_kv_put( + _make_key(self.prefix, category, key), value, overwrite=True + ) + self._registered.add((category, key)) + self._to_flush.clear() + + +_global_registry = _Registry() +ray._private.worker._post_init_hooks.append(_global_registry.flush_values) + + +class _ParameterRegistry: + def __init__(self): + self.to_flush = {} + self.references = {} + + def put(self, k, v): + self.to_flush[k] = v + if ray.is_initialized(): + self.flush() + + def get(self, k): + if not ray.is_initialized(): + return self.to_flush[k] + return ray.get(self.references[k]) + + def flush(self): + for k, v in self.to_flush.items(): + if isinstance(v, ray.ObjectRef): + self.references[k] = v + else: + self.references[k] = ray.put(v) + self.to_flush.clear() diff --git a/.venv/lib/python3.11/site-packages/ray/tune/resources.py b/.venv/lib/python3.11/site-packages/ray/tune/resources.py new file mode 100644 index 0000000000000000000000000000000000000000..6c6113ceac03b8d554c8cdca962ddf23ec487c10 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/resources.py @@ -0,0 +1,92 @@ +import json +import logging +from collections import namedtuple + +# For compatibility under py2 to consider 
unicode as str +from typing import Optional + +from ray.tune.error import TuneError +from ray.tune.execution.placement_groups import ( + PlacementGroupFactory, + resource_dict_to_pg_factory, +) +from ray.tune.utils.resource_updater import _Resources +from ray.util.annotations import Deprecated, DeveloperAPI + +logger = logging.getLogger(__name__) + + +@Deprecated +class Resources( + namedtuple( + "Resources", + [ + "cpu", + "gpu", + "memory", + "object_store_memory", + "extra_cpu", + "extra_gpu", + "extra_memory", + "extra_object_store_memory", + "custom_resources", + "extra_custom_resources", + "has_placement_group", + ], + ) +): + __slots__ = () + + def __new__( + cls, + cpu: float, + gpu: float, + memory: float = 0, + object_store_memory: float = 0.0, + extra_cpu: float = 0.0, + extra_gpu: float = 0.0, + extra_memory: float = 0.0, + extra_object_store_memory: float = 0.0, + custom_resources: Optional[dict] = None, + extra_custom_resources: Optional[dict] = None, + has_placement_group: bool = False, + ): + raise DeprecationWarning( + "tune.Resources is depracted. Use tune.PlacementGroupFactory instead." + ) + + +@DeveloperAPI +def json_to_resources(data: Optional[str]) -> Optional[PlacementGroupFactory]: + if data is None or data == "null": + return None + if isinstance(data, str): + data = json.loads(data) + + for k in data: + if k in ["driver_cpu_limit", "driver_gpu_limit"]: + raise TuneError( + "The field `{}` is no longer supported. Use `extra_cpu` " + "or `extra_gpu` instead.".format(k) + ) + if k not in _Resources._fields: + raise ValueError( + "Unknown resource field {}, must be one of {}".format( + k, Resources._fields + ) + ) + resource_dict_to_pg_factory( + dict( + cpu=data.get("cpu", 1), + gpu=data.get("gpu", 0), + memory=data.get("memory", 0), + custom_resources=data.get("custom_resources"), + ) + ) + + +@Deprecated +def resources_to_json(*args, **kwargs): + raise DeprecationWarning( + "tune.Resources is depracted. Use tune.PlacementGroupFactory instead." + ) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/result.py b/.venv/lib/python3.11/site-packages/ray/tune/result.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e966386a12743df99f3bb9aa243a07ce8b1409 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/result.py @@ -0,0 +1,133 @@ +# Importing for Backward Compatibility +from ray.air.constants import ( # noqa: F401 + EXPR_ERROR_FILE, + EXPR_ERROR_PICKLE_FILE, + EXPR_PARAM_FILE, + EXPR_PARAM_PICKLE_FILE, + EXPR_PROGRESS_FILE, + EXPR_RESULT_FILE, + TIME_THIS_ITER_S, + TIMESTAMP, + TRAINING_ITERATION, +) + +# fmt: off +# __sphinx_doc_begin__ +# (Optional/Auto-filled) training is terminated. Filled only if not provided. +DONE = "done" + +# (Optional) Enum for user controlled checkpoint +SHOULD_CHECKPOINT = "should_checkpoint" + +# (Auto-filled) The hostname of the machine hosting the training process. +HOSTNAME = "hostname" + +# (Auto-filled) The auto-assigned id of the trial. +TRIAL_ID = "trial_id" + +# (Auto-filled) The auto-assigned id of the trial. +EXPERIMENT_TAG = "experiment_tag" + +# (Auto-filled) The node ip of the machine hosting the training process. +NODE_IP = "node_ip" + +# (Auto-filled) The pid of the training process. 
+PID = "pid" + +# (Optional) Default (anonymous) metric when using tune.report(x) +DEFAULT_METRIC = "_metric" + +# (Optional) Mean reward for current training iteration +EPISODE_REWARD_MEAN = "episode_reward_mean" + +# (Optional) Mean loss for training iteration +MEAN_LOSS = "mean_loss" + +# (Optional) Mean accuracy for training iteration +MEAN_ACCURACY = "mean_accuracy" + +# Number of episodes in this iteration. +EPISODES_THIS_ITER = "episodes_this_iter" + +# (Optional/Auto-filled) Accumulated number of episodes for this trial. +EPISODES_TOTAL = "episodes_total" + +# Number of timesteps in this iteration. +TIMESTEPS_THIS_ITER = "timesteps_this_iter" + +# (Auto-filled) Accumulated number of timesteps for this entire trial. +TIMESTEPS_TOTAL = "timesteps_total" + +# (Auto-filled) Accumulated time in seconds for this entire trial. +TIME_TOTAL_S = "time_total_s" + +# __sphinx_doc_end__ +# fmt: on + +DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", EXPERIMENT_TAG, TRIAL_ID) + +DEFAULT_RESULT_KEYS = ( + TRAINING_ITERATION, + TIME_TOTAL_S, + MEAN_ACCURACY, + MEAN_LOSS, +) + +# Metrics that don't require at least one iteration to complete +DEBUG_METRICS = ( + TRIAL_ID, + "experiment_id", + "date", + TIMESTAMP, + PID, + HOSTNAME, + NODE_IP, + "config", +) + +# Make sure this doesn't regress +AUTO_RESULT_KEYS = ( + TRAINING_ITERATION, + TIME_TOTAL_S, + EPISODES_TOTAL, + TIMESTEPS_TOTAL, + NODE_IP, + HOSTNAME, + PID, + TIME_TOTAL_S, + TIME_THIS_ITER_S, + TIMESTAMP, + "date", + "time_since_restore", + "timesteps_since_restore", + "iterations_since_restore", + "config", + # TODO(justinvyu): Move this stuff to train to avoid cyclical dependency. + "checkpoint_dir_name", +) + +# __duplicate__ is a magic keyword used internally to +# avoid double-logging results when using the Function API. +RESULT_DUPLICATE = "__duplicate__" + +# __trial_info__ is a magic keyword used internally to pass trial_info +# to the Trainable via the constructor. +TRIAL_INFO = "__trial_info__" + +# __stdout_file__/__stderr_file__ are magic keywords used internally +# to pass log file locations to the Trainable via the constructor. +STDOUT_FILE = "__stdout_file__" +STDERR_FILE = "__stderr_file__" + +DEFAULT_EXPERIMENT_NAME = "default" + +# Meta file about status under each experiment directory, can be +# parsed by automlboard if exists. +JOB_META_FILE = "job_status.json" + +# Meta file about status under each trial directory, can be parsed +# by automlboard if exists. +EXPR_META_FILE = "trial_status.json" + +# Config prefix when using ExperimentAnalysis. +CONFIG_PREFIX = "config" diff --git a/.venv/lib/python3.11/site-packages/ray/tune/result_grid.py b/.venv/lib/python3.11/site-packages/ray/tune/result_grid.py new file mode 100644 index 0000000000000000000000000000000000000000..7dffe6d4614ab6e6d4f09c722af99bac6f5c971e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/result_grid.py @@ -0,0 +1,285 @@ +from typing import Optional, Union + +import pandas as pd +import pyarrow + +from ray.air.result import Result +from ray.exceptions import RayTaskError +from ray.tune.analysis import ExperimentAnalysis +from ray.tune.error import TuneError +from ray.tune.experiment import Trial +from ray.util import PublicAPI + + +@PublicAPI(stability="beta") +class ResultGrid: + """A set of ``Result`` objects for interacting with Ray Tune results. + + You can use it to inspect the trials and obtain the best result. + + The constructor is a private API. This object can only be created as a result of + ``Tuner.fit()``. + + Example: + .. 
testcode:: + + import random + from ray import train, tune + def random_error_trainable(config): + if random.random() < 0.5: + return {"loss": 0.0} + else: + raise ValueError("This is an error") + tuner = tune.Tuner( + random_error_trainable, + run_config=train.RunConfig(name="example-experiment"), + tune_config=tune.TuneConfig(num_samples=10), + ) + try: + result_grid = tuner.fit() + except ValueError: + pass + for i in range(len(result_grid)): + result = result_grid[i] + if not result.error: + print(f"Trial finishes successfully with metrics" + f"{result.metrics}.") + else: + print(f"Trial failed with error {result.error}.") + + .. testoutput:: + :hide: + + ... + + You can also use ``result_grid`` for more advanced analysis. + + >>> # Get the best result based on a particular metric. + >>> best_result = result_grid.get_best_result( # doctest: +SKIP + ... metric="loss", mode="min") + >>> # Get the best checkpoint corresponding to the best result. + >>> best_checkpoint = best_result.checkpoint # doctest: +SKIP + >>> # Get a dataframe for the last reported results of all of the trials + >>> df = result_grid.get_dataframe() # doctest: +SKIP + >>> # Get a dataframe for the minimum loss seen for each trial + >>> df = result_grid.get_dataframe(metric="loss", mode="min") # doctest: +SKIP + + Note that trials of all statuses are included in the final result grid. + If a trial is not in terminated state, its latest result and checkpoint as + seen by Tune will be provided. + + See :doc:`/tune/examples/tune_analyze_results` for more usage examples. + """ + + def __init__( + self, + experiment_analysis: ExperimentAnalysis, + ): + self._experiment_analysis = experiment_analysis + self._results = [ + self._trial_to_result(trial) for trial in self._experiment_analysis.trials + ] + + @property + def experiment_path(self) -> str: + """Path pointing to the experiment directory on persistent storage. + + This can point to a remote storage location (e.g. S3) or to a local + location (path on the head node).""" + return self._experiment_analysis.experiment_path + + @property + def filesystem(self) -> pyarrow.fs.FileSystem: + """Return the filesystem that can be used to access the experiment path. + + Returns: + pyarrow.fs.FileSystem implementation. + """ + return self._experiment_analysis._fs + + def get_best_result( + self, + metric: Optional[str] = None, + mode: Optional[str] = None, + scope: str = "last", + filter_nan_and_inf: bool = True, + ) -> Result: + """Get the best result from all the trials run. + + Args: + metric: Key for trial info to order on. Defaults to + the metric specified in your Tuner's ``TuneConfig``. + mode: One of [min, max]. Defaults to the mode specified + in your Tuner's ``TuneConfig``. + scope: One of [all, last, avg, last-5-avg, last-10-avg]. + If `scope=last`, only look at each trial's final step for + `metric`, and compare across trials based on `mode=[min,max]`. + If `scope=avg`, consider the simple average over all steps + for `metric` and compare across trials based on + `mode=[min,max]`. If `scope=last-5-avg` or `scope=last-10-avg`, + consider the simple average over the last 5 or 10 steps for + `metric` and compare across trials based on `mode=[min,max]`. + If `scope=all`, find each trial's min/max score for `metric` + based on `mode`, and compare trials based on `mode=[min,max]`. + filter_nan_and_inf: If True (default), NaN or infinite + values are disregarded and these trials are never selected as + the best trial. 
+ """ + if len(self._experiment_analysis.trials) == 1: + return self._trial_to_result(self._experiment_analysis.trials[0]) + if not metric and not self._experiment_analysis.default_metric: + raise ValueError( + "No metric is provided. Either pass in a `metric` arg to " + "`get_best_result` or specify a metric in the " + "`TuneConfig` of your `Tuner`." + ) + if not mode and not self._experiment_analysis.default_mode: + raise ValueError( + "No mode is provided. Either pass in a `mode` arg to " + "`get_best_result` or specify a mode in the " + "`TuneConfig` of your `Tuner`." + ) + + best_trial = self._experiment_analysis.get_best_trial( + metric=metric, + mode=mode, + scope=scope, + filter_nan_and_inf=filter_nan_and_inf, + ) + if not best_trial: + error_msg = ( + "No best trial found for the given metric: " + f"{metric or self._experiment_analysis.default_metric}. " + "This means that no trial has reported this metric" + ) + error_msg += ( + ", or all values reported for this metric are NaN. To not ignore NaN " + "values, you can set the `filter_nan_and_inf` arg to False." + if filter_nan_and_inf + else "." + ) + raise RuntimeError(error_msg) + + return self._trial_to_result(best_trial) + + def get_dataframe( + self, + filter_metric: Optional[str] = None, + filter_mode: Optional[str] = None, + ) -> pd.DataFrame: + """Return dataframe of all trials with their configs and reported results. + + Per default, this returns the last reported results for each trial. + + If ``filter_metric`` and ``filter_mode`` are set, the results from each + trial are filtered for this metric and mode. For example, if + ``filter_metric="some_metric"`` and ``filter_mode="max"``, for each trial, + every received result is checked, and the one where ``some_metric`` is + maximal is returned. + + + Example: + + .. testcode:: + + from ray import train + from ray.train import RunConfig + from ray.tune import Tuner + + def training_loop_per_worker(config): + train.report({"accuracy": 0.8}) + + result_grid = Tuner( + trainable=training_loop_per_worker, + run_config=RunConfig(name="my_tune_run") + ).fit() + + # Get last reported results per trial + df = result_grid.get_dataframe() + + # Get best ever reported accuracy per trial + df = result_grid.get_dataframe( + filter_metric="accuracy", filter_mode="max" + ) + + .. testoutput:: + :hide: + + ... + + Args: + filter_metric: Metric to filter best result for. + filter_mode: If ``filter_metric`` is given, one of ``["min", "max"]`` + to specify if we should find the minimum or maximum result. + + Returns: + Pandas DataFrame with each trial as a row and their results as columns. 
+ """ + return self._experiment_analysis.dataframe( + metric=filter_metric, mode=filter_mode + ) + + def __len__(self) -> int: + return len(self._results) + + def __getitem__(self, i: int) -> Result: + """Returns the i'th result in the grid.""" + return self._results[i] + + @property + def errors(self): + """Returns the exceptions of errored trials.""" + return [result.error for result in self if result.error] + + @property + def num_errors(self): + """Returns the number of errored trials.""" + return len( + [t for t in self._experiment_analysis.trials if t.status == Trial.ERROR] + ) + + @property + def num_terminated(self): + """Returns the number of terminated (but not errored) trials.""" + return len( + [ + t + for t in self._experiment_analysis.trials + if t.status == Trial.TERMINATED + ] + ) + + @staticmethod + def _populate_exception(trial: Trial) -> Optional[Union[TuneError, RayTaskError]]: + if trial.status == Trial.TERMINATED: + return None + return trial.get_pickled_error() or trial.get_error() + + def _trial_to_result(self, trial: Trial) -> Result: + cpm = trial.run_metadata.checkpoint_manager + checkpoint = None + if cpm.latest_checkpoint_result: + checkpoint = cpm.latest_checkpoint_result.checkpoint + best_checkpoint_results = cpm.best_checkpoint_results + best_checkpoints = [ + (checkpoint_result.checkpoint, checkpoint_result.metrics) + for checkpoint_result in best_checkpoint_results + ] + + metrics_df = self._experiment_analysis.trial_dataframes.get(trial.trial_id) + + result = Result( + checkpoint=checkpoint, + metrics=trial.last_result.copy(), + error=self._populate_exception(trial), + path=trial.path, + _storage_filesystem=self._experiment_analysis._fs, + metrics_dataframe=metrics_df, + best_checkpoints=best_checkpoints, + ) + return result + + def __repr__(self) -> str: + all_results_repr = [result._repr(indent=2) for result in self] + all_results_repr = ",\n".join(all_results_repr) + return f"ResultGrid<[\n{all_results_repr}\n]>" diff --git a/.venv/lib/python3.11/site-packages/ray/tune/syncer.py b/.venv/lib/python3.11/site-packages/ray/tune/syncer.py new file mode 100644 index 0000000000000000000000000000000000000000..db1ca47e12dbe488a58a2aeef8d144f52e9209ec --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/syncer.py @@ -0,0 +1,15 @@ +import logging + +from ray.train._internal.syncer import SyncConfig as TrainSyncConfig +from ray.util.annotations import Deprecated + +logger = logging.getLogger(__name__) + + +@Deprecated +class SyncConfig(TrainSyncConfig): + def __new__(cls: type, *args, **kwargs): + raise DeprecationWarning( + "`ray.tune.SyncConfig` has been moved to `ray.train.SyncConfig`. " + "Please update your code to use `ray.train.SyncConfig`." 
+ ) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/trainable/__init__.py b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a308feb13422e477bb5c149ab6082eb2c3e7f0c4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__init__.py @@ -0,0 +1,10 @@ +from ray.tune.trainable.function_trainable import FunctionTrainable, wrap_function +from ray.tune.trainable.trainable import Trainable +from ray.tune.trainable.util import with_parameters + +__all__ = [ + "Trainable", + "FunctionTrainable", + "with_parameters", + "wrap_function", +] diff --git a/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..950458ea158d1d0d3f5eae84845896ed43825979 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/function_trainable.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/function_trainable.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4929a33fdd57b1aa7bbe82f8dae53b6ad2007954 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/function_trainable.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/metadata.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/metadata.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a129027b3a4fc0ab2522d9a554274a6bc9297ba1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/metadata.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/trainable.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/trainable.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5633d94aacf68b6115054d955986dda2064bb27b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/trainable.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/trainable_fn_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/trainable_fn_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a9b2ba056984776edd86ea16afd475ab6696711 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/trainable_fn_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/util.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b84858a1f794a1216669dd2ec5fab1ef822bd5bd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/trainable/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/trainable/function_trainable.py b/.venv/lib/python3.11/site-packages/ray/tune/trainable/function_trainable.py new file mode 100644 index 
0000000000000000000000000000000000000000..9dc9ff02cbfd6976eedb7d9db4b20d0bceed161e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/trainable/function_trainable.py @@ -0,0 +1,278 @@ +import inspect +import logging +import os +import queue +from functools import partial +from numbers import Number +from typing import Any, Callable, Dict, Optional, Type + +from ray.air._internal.util import RunnerThread, StartTraceback +from ray.air.constants import _ERROR_FETCH_TIMEOUT +from ray.train._internal.checkpoint_manager import _TrainingResult +from ray.train._internal.session import ( + TrialInfo, + _TrainSession, + get_session, + init_session, + shutdown_session, +) +from ray.train.v2._internal.constants import RUN_CONTROLLER_AS_ACTOR_ENV_VAR +from ray.tune.execution.placement_groups import PlacementGroupFactory +from ray.tune.result import DEFAULT_METRIC, RESULT_DUPLICATE, SHOULD_CHECKPOINT +from ray.tune.trainable.trainable import Trainable +from ray.tune.utils import _detect_config_single +from ray.util.annotations import DeveloperAPI + +logger = logging.getLogger(__name__) + +# Time between FunctionTrainable checks when fetching +# new results after signaling the reporter to continue + +NULL_MARKER = ".null_marker" +TEMP_MARKER = ".temp_marker" + + +@DeveloperAPI +class FunctionTrainable(Trainable): + """Trainable that runs a user function reporting results. + + This mode of execution does not support checkpoint/restore.""" + + _name = "func" + + def setup(self, config): + init_session( + training_func=lambda: self._trainable_func(self.config), + trial_info=TrialInfo( + name=self.trial_name, + id=self.trial_id, + resources=self.trial_resources, + logdir=self._storage.trial_driver_staging_path, + driver_ip=None, + driver_node_id=None, + experiment_name=self._storage.experiment_dir_name, + ), + storage=self._storage, + synchronous_result_reporting=True, + # Set all Train-specific properties to None. + world_rank=None, + local_rank=None, + node_rank=None, + local_world_size=None, + world_size=None, + dataset_shard=None, + checkpoint=None, + ) + self._last_training_result: Optional[_TrainingResult] = None + + # NOTE: This environment variable is used to disable the + # spawning a new actor for Ray Train drivers being launched + # within Tune functions. + # There are 2 reasons for this: + # 1. Ray Tune already spawns an actor, so we can run the Ray Train + # driver directly in the same actor. + # 2. This allows `ray.tune.report` to be called within Ray Train driver + # callbacks, since it needs to be called on the same process as the + # Tune FunctionTrainable actor. + os.environ[RUN_CONTROLLER_AS_ACTOR_ENV_VAR] = "0" + + def _trainable_func(self, config: Dict[str, Any]): + """Subclasses can override this to set the trainable func.""" + + raise NotImplementedError + + def _start(self): + def entrypoint(): + try: + return self._trainable_func(self.config) + except Exception as e: + raise StartTraceback from e + + # the runner thread is not started until the first call to _train + self._runner = RunnerThread( + target=entrypoint, error_queue=self._error_queue, daemon=True + ) + # if not alive, try to start + self._status_reporter._start() + try: + self._runner.start() + except RuntimeError: + # If this is reached, it means the thread was started and is + # now done or has raised an exception. + pass + + def step(self): + """Implements train() for a Function API. 
+ + If the RunnerThread finishes without reporting "done", + Tune will automatically provide a magic keyword __duplicate__ + along with a result with "done=True". The TrialRunner will handle the + result accordingly (see tune/tune_controller.py). + """ + session: _TrainSession = get_session() + if not session.training_started: + session.start() + + training_result: Optional[_TrainingResult] = session.get_next() + + if not training_result: + # The `RESULT_DUPLICATE` result should have been the last + # result reported by the session, which triggers cleanup. + raise RuntimeError( + "Should not have reached here. The TuneController should not " + "have scheduled another `train` remote call. " + "It should have scheduled a `stop` instead " + "after the training function exits." + ) + + metrics = training_result.metrics + # This keyword appears if the train_func using the Function API + # finishes without "done=True". This duplicates the last result, but + # the TuneController will not log this result again. + # TuneController will also inject done=True to the result, + # and proceed to queue up a STOP decision for the trial. + if RESULT_DUPLICATE in metrics: + metrics[SHOULD_CHECKPOINT] = False + + self._last_training_result = training_result + if training_result.checkpoint is not None: + # TODO(justinvyu): Result/checkpoint reporting can be combined. + # For now, since result/checkpoint reporting is separate, this + # special key will tell Tune to pull the checkpoint from + # the `last_training_result`. + metrics[SHOULD_CHECKPOINT] = True + return metrics + + def execute(self, fn): + return fn(self) + + def save_checkpoint(self, checkpoint_dir: str = ""): + if checkpoint_dir: + raise ValueError("Checkpoint dir should not be used with function API.") + + # TODO(justinvyu): This currently breaks the `save_checkpoint` interface. + # TRAIN -> SAVE remote calls get processed sequentially, + # so `_last_training_result.checkpoint` holds onto the latest ckpt. + return self._last_training_result + + def load_checkpoint(self, checkpoint_result: _TrainingResult): + # TODO(justinvyu): This currently breaks the `load_checkpoint` interface. + session = get_session() + session.loaded_checkpoint = checkpoint_result.checkpoint + + def cleanup(self): + session = get_session() + try: + # session.finish raises any Exceptions from training. + # Do not wait for thread termination here (timeout=0). + session.finish(timeout=0) + finally: + # Check for any errors that might have been missed. + session._report_thread_runner_error() + # Shutdown session even if session.finish() raises an Exception. + shutdown_session() + + def reset_config(self, new_config): + session = get_session() + + # Wait for thread termination so it is safe to re-use the same actor. + thread_timeout = int(os.environ.get("TUNE_FUNCTION_THREAD_TIMEOUT_S", 2)) + session.finish(timeout=thread_timeout) + if session.training_thread.is_alive(): + # Did not finish within timeout, reset unsuccessful.
+ return False + + session.reset( + training_func=lambda: self._trainable_func(self.config), + trial_info=TrialInfo( + name=self.trial_name, + id=self.trial_id, + resources=self.trial_resources, + logdir=self._storage.trial_working_directory, + driver_ip=None, + driver_node_id=None, + experiment_name=self._storage.experiment_dir_name, + ), + storage=self._storage, + ) + + self._last_result = {} + return True + + def _report_thread_runner_error(self, block=False): + try: + e = self._error_queue.get(block=block, timeout=_ERROR_FETCH_TIMEOUT) + raise StartTraceback from e + except queue.Empty: + pass + + +@DeveloperAPI +def wrap_function( + train_func: Callable[[Any], Any], name: Optional[str] = None +) -> Type["FunctionTrainable"]: + inherit_from = (FunctionTrainable,) + + if hasattr(train_func, "__mixins__"): + inherit_from = train_func.__mixins__ + inherit_from + + func_args = inspect.getfullargspec(train_func).args + use_config_single = _detect_config_single(train_func) + + if not use_config_single: + raise ValueError( + "Unknown argument found in the Trainable function. " + "The function args must include a single 'config' positional parameter.\n" + "Found: {}".format(func_args) + ) + + resources = getattr(train_func, "_resources", None) + + class ImplicitFunc(*inherit_from): + _name = name or ( + train_func.__name__ if hasattr(train_func, "__name__") else "func" + ) + + def __repr__(self): + return self._name + + def _trainable_func(self, config): + fn = partial(train_func, config) + + def handle_output(output): + if not output: + return + elif isinstance(output, dict): + get_session().report(output) + elif isinstance(output, Number): + get_session().report({DEFAULT_METRIC: output}) + else: + raise ValueError( + "Invalid return or yield value. Either return/yield " + "a single number or a dictionary object in your " + "trainable function." + ) + + output = None + if inspect.isgeneratorfunction(train_func): + for output in fn(): + handle_output(output) + else: + output = fn() + handle_output(output) + + # If train_func returns, we need to notify the main event loop + # of the last result while avoiding double logging. This is done + # with the keyword RESULT_DUPLICATE -- see tune/tune_controller.py. + get_session().report({RESULT_DUPLICATE: True}) + return output + + @classmethod + def default_resource_request( + cls, config: Dict[str, Any] + ) -> Optional[PlacementGroupFactory]: + if not isinstance(resources, PlacementGroupFactory) and callable(resources): + return resources(config) + return resources + + return ImplicitFunc diff --git a/.venv/lib/python3.11/site-packages/ray/tune/trainable/metadata.py b/.venv/lib/python3.11/site-packages/ray/tune/trainable/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..a520371e48503b95a64c4449c48510d960e78ef3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/trainable/metadata.py @@ -0,0 +1,102 @@ +import json +from collections import deque +from numbers import Number +from typing import Optional, Tuple + +from ray.train._internal.checkpoint_manager import _CheckpointManager +from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder + + +class _TrainingRunMetadata: + """Serializable struct for holding runtime trial metadata. + + Runtime metadata is data that changes and is updated on runtime. This includes + e.g. the last result, the currently available checkpoints, and the number + of errors encountered for a trial. 
+ """ + + def __init__(self, n_steps: Tuple[int] = (5, 10)): + # General metadata + self.start_time = None + + # Errors + self.num_failures = 0 + self.num_failures_after_restore = 0 + + self.error_filename = None + self.pickled_error_filename = None + + # Results and metrics + self.last_result = {} + self.last_result_time = -float("inf") + + # stores in memory max/min/avg/last-n-avg/last result for each + # metric by trial + self.metric_analysis = {} + self._n_steps = n_steps + self.metric_n_steps = {} + + # Checkpoints + self.checkpoint_manager: Optional[_CheckpointManager] = None + + self._cached_json = None + + def invalidate_cache(self): + self._cached_json = None + + def update_metric(self, metric: str, value: Number, step: Optional[int] = 1): + if metric not in self.metric_analysis: + self.metric_analysis[metric] = { + "max": value, + "min": value, + "avg": value, + "last": value, + } + self.metric_n_steps[metric] = {} + for n in self._n_steps: + key = "last-{:d}-avg".format(n) + self.metric_analysis[metric][key] = value + # Store n as string for correct restore. + self.metric_n_steps[metric][str(n)] = deque([value], maxlen=n) + else: + step = step or 1 + self.metric_analysis[metric]["max"] = max( + value, self.metric_analysis[metric]["max"] + ) + self.metric_analysis[metric]["min"] = min( + value, self.metric_analysis[metric]["min"] + ) + self.metric_analysis[metric]["avg"] = ( + 1 / step * (value + (step - 1) * self.metric_analysis[metric]["avg"]) + ) + self.metric_analysis[metric]["last"] = value + + for n in self._n_steps: + key = "last-{:d}-avg".format(n) + self.metric_n_steps[metric][str(n)].append(value) + self.metric_analysis[metric][key] = sum( + self.metric_n_steps[metric][str(n)] + ) / len(self.metric_n_steps[metric][str(n)]) + self.invalidate_cache() + + def __setattr__(self, key, value): + super().__setattr__(key, value) + if key not in {"_cached_json"}: + self.invalidate_cache() + + def get_json_state(self) -> str: + if self._cached_json is None: + data = self.__dict__ + data.pop("_cached_json", None) + self._cached_json = json.dumps(data, indent=2, cls=TuneFunctionEncoder) + + return self._cached_json + + @classmethod + def from_json_state(cls, json_state: str) -> "_TrainingRunMetadata": + state = json.loads(json_state, cls=TuneFunctionDecoder) + + run_metadata = cls() + run_metadata.__dict__.update(state) + + return run_metadata diff --git a/.venv/lib/python3.11/site-packages/ray/tune/trainable/trainable.py b/.venv/lib/python3.11/site-packages/ray/tune/trainable/trainable.py new file mode 100644 index 0000000000000000000000000000000000000000..0714e8b6778ae92b2f52d3ddf528c5f3b7ebb0ba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/trainable/trainable.py @@ -0,0 +1,995 @@ +import copy +import logging +import os +import platform +import shutil +import sys +import tempfile +import time +from contextlib import redirect_stderr, redirect_stdout +from datetime import datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union + +import ray +import ray.cloudpickle as ray_pickle +from ray.air._internal.util import exception_cause, skip_exceptions +from ray.air.constants import TIME_THIS_ITER_S, TIMESTAMP, TRAINING_ITERATION +from ray.train import Checkpoint +from ray.train._internal.checkpoint_manager import _TrainingResult +from ray.train._internal.storage import StorageContext, _exists_at_fs_path +from ray.train.constants import DEFAULT_STORAGE_PATH +from ray.tune.execution.placement_groups import 
PlacementGroupFactory +from ray.tune.result import ( + DEBUG_METRICS, + DONE, + EPISODES_THIS_ITER, + EPISODES_TOTAL, + HOSTNAME, + NODE_IP, + PID, + RESULT_DUPLICATE, + SHOULD_CHECKPOINT, + STDERR_FILE, + STDOUT_FILE, + TIME_TOTAL_S, + TIMESTEPS_THIS_ITER, + TIMESTEPS_TOTAL, + TRIAL_ID, + TRIAL_INFO, +) +from ray.tune.utils import UtilMonitor +from ray.tune.utils.log import disable_ipython +from ray.tune.utils.util import Tee +from ray.util.annotations import DeveloperAPI, PublicAPI + +if TYPE_CHECKING: + from ray.tune.logger import Logger + +logger = logging.getLogger(__name__) + +SETUP_TIME_THRESHOLD = 10 + +# File containing dict data returned by user from `Trainable.save_checkpoint` +_DICT_CHECKPOINT_FILE_NAME = "_dict_checkpoint.pkl" + + +@PublicAPI +class Trainable: + """Abstract class for trainable models, functions, etc. + + A call to ``train()`` on a trainable will execute one logical iteration of + training. As a rule of thumb, the execution time of one train call should + be large enough to avoid overheads (i.e. more than a few seconds), but + short enough to report progress periodically (i.e. at most a few minutes). + + Calling ``save()`` should save the training state of a trainable to disk, + and ``restore(path)`` should restore a trainable to the given state. + + Generally you only need to implement ``setup``, ``step``, + ``save_checkpoint``, and ``load_checkpoint`` when subclassing Trainable. + + Other implementation methods that may be helpful to override are + ``log_result``, ``reset_config``, ``cleanup``, and ``_export_model``. + + Tune will convert this class into a Ray actor, which runs on a separate process. + By default, Tune will also change the current working directory of this process to + its corresponding trial-level log directory ``self.logdir``. + This is designed so that different trials that run on the same physical node won't + accidentally write to the same location and overstep each other. + + The behavior of changing the working directory can be disabled by setting the + `RAY_CHDIR_TO_TRIAL_DIR=0` environment variable. This allows access to files + in the original working directory, but relative paths should be used for read only + purposes, and you must make sure that the directory is synced on all nodes if + running on multiple machines. + + The `TUNE_ORIG_WORKING_DIR` environment variable was the original workaround for + accessing paths relative to the original working directory. This environment + variable is deprecated, and the `RAY_CHDIR_TO_TRIAL_DIR` environment variable + described above should be used instead. + + This class supports checkpointing to and restoring from remote storage. + """ + + def __init__( + self, + config: Dict[str, Any] = None, + logger_creator: Callable[[Dict[str, Any]], "Logger"] = None, # Deprecated (2.7) + storage: Optional[StorageContext] = None, + ): + """Initialize a Trainable. + + Sets up logging and points ``self.logdir`` to a directory in which + training outputs should be placed. + + Subclasses should prefer defining ``setup()`` instead of overriding + ``__init__()`` directly. + + Args: + config: Trainable-specific configuration data. By default + will be saved as ``self.config``. + logger_creator: (Deprecated) Function that creates a ray.tune.Logger + object. If unspecified, a default logger is created. 
+ storage: StorageContext object that contains persistent storage paths + """ + + self.config = config or {} + trial_info = self.config.pop(TRIAL_INFO, None) + + if self.is_actor(): + disable_ipython() + + # TODO(ml-team): Remove `logger_creator` in 2.7. + # TODO(justinvyu): Rename/remove logdir. + self._result_logger = self._logdir = None + self._create_logger(self.config, logger_creator) + + self._stdout_context = self._stdout_fp = self._stdout_stream = None + self._stderr_context = self._stderr_fp = self._stderr_stream = None + self._stderr_logging_handler = None + + stdout_file = self.config.pop(STDOUT_FILE, None) + stderr_file = self.config.pop(STDERR_FILE, None) + + self._iteration = 0 + self._time_total = 0.0 + self._timesteps_total = None + self._episodes_total = None + self._time_since_restore = 0.0 + self._timesteps_since_restore = 0 + self._iterations_since_restore = 0 + self._last_result = None + self._restored = False + self._trial_info = trial_info + self._stdout_file = stdout_file + self._stderr_file = stderr_file + + self._start_time = time.time() + self._local_ip = ray.util.get_node_ip_address() + + self._storage = storage + if storage: + assert storage.trial_fs_path + logger.debug(f"StorageContext on the TRAINABLE:\n{storage}") + + self._open_logfiles(stdout_file, stderr_file) + + self.setup(copy.deepcopy(self.config)) + setup_time = time.time() - self._start_time + if setup_time > SETUP_TIME_THRESHOLD: + logger.info( + "Trainable.setup took {:.3f} seconds. If your " + "trainable is slow to initialize, consider setting " + "reuse_actors=True to reduce actor creation " + "overheads.".format(setup_time) + ) + log_sys_usage = self.config.get("log_sys_usage", False) + self._monitor = UtilMonitor(start=log_sys_usage) + + @classmethod + def default_resource_request( + cls, config: Dict[str, Any] + ) -> Optional[PlacementGroupFactory]: + """Provides a static resource requirement for the given configuration. + + This can be overridden by sub-classes to set the correct trial resource + allocation, so the user does not need to. + + .. testcode:: + + @classmethod + def default_resource_request(cls, config): + return PlacementGroupFactory([{"CPU": 1}, {"CPU": 1}]) + + + Args: + config[Dict[str, Any]]: The Trainable's config dict. + + Returns: + PlacementGroupFactory: A PlacementGroupFactory consumed by Tune + for queueing. + """ + return None + + @classmethod + def resource_help(cls, config: Dict): + """Returns a help string for configuring this trainable's resources. + + Args: + config: The Trainer's config dict. + """ + return "" + + def get_current_ip_pid(self): + return self._local_ip, os.getpid() + + def get_auto_filled_metrics( + self, + now: Optional[datetime] = None, + time_this_iter: Optional[float] = None, + timestamp: Optional[int] = None, + debug_metrics_only: bool = False, + ) -> dict: + """Return a dict with metrics auto-filled by the trainable. + + If ``debug_metrics_only`` is True, only metrics that don't + require at least one iteration will be returned + (``ray.tune.result.DEBUG_METRICS``). 
+ """ + if now is None: + now = datetime.today() + autofilled = { + TRIAL_ID: self.trial_id, + "date": now.strftime("%Y-%m-%d_%H-%M-%S"), + "timestamp": timestamp if timestamp else int(time.mktime(now.timetuple())), + TIME_THIS_ITER_S: time_this_iter, + TIME_TOTAL_S: self._time_total, + PID: os.getpid(), + HOSTNAME: platform.node(), + NODE_IP: self._local_ip, + "config": self.config, + "time_since_restore": self._time_since_restore, + "iterations_since_restore": self._iterations_since_restore, + } + if self._timesteps_since_restore: + autofilled["timesteps_since_restore"] = self._timesteps_since_restore + + if debug_metrics_only: + autofilled = {k: v for k, v in autofilled.items() if k in DEBUG_METRICS} + return autofilled + + def is_actor(self): + try: + actor_id = ray._private.worker.global_worker.actor_id + return actor_id != actor_id.nil() + except Exception: + # If global_worker is not instantiated, we're not in an actor + return False + + def train_buffered(self, buffer_time_s: float, max_buffer_length: int = 1000): + """Runs multiple iterations of training. + + Calls ``train()`` internally. Collects and combines multiple results. + This function will run ``self.train()`` repeatedly until one of + the following conditions is met: 1) the maximum buffer length is + reached, 2) the maximum buffer time is reached, or 3) a checkpoint + was created. Even if the maximum time is reached, it will always + block until at least one result is received. + + Args: + buffer_time_s: Maximum time to buffer. The next result + received after this amount of time has passed will return + the whole buffer. + max_buffer_length: Maximum number of results to buffer. + + """ + results = [] + + now = time.time() + send_buffer_at = now + buffer_time_s + while now < send_buffer_at or not results: # At least one result + result = self.train() + results.append(result) + if result.get(DONE, False): + # If the trial is done, return + break + elif result.get(SHOULD_CHECKPOINT, False): + # If a checkpoint was created, return + break + elif result.get(RESULT_DUPLICATE): + # If the function API trainable completed, return + break + elif len(results) >= max_buffer_length: + # If the buffer is full, return + break + now = time.time() + + return results + + def train(self): + """Runs one logical iteration of training. + + Calls ``step()`` internally. Subclasses should override ``step()`` + instead to return results. + This method automatically fills the following fields in the result: + + `done` (bool): training is terminated. Filled only if not provided. + + `time_this_iter_s` (float): Time in seconds this iteration + took to run. This may be overridden in order to override the + system-computed time difference. + + `time_total_s` (float): Accumulated time in seconds for this + entire experiment. + + `training_iteration` (int): The index of this + training iteration, e.g. call to train(). This is incremented + after `step()` is called. + + `pid` (str): The pid of the training process. + + `date` (str): A formatted date of when the result was processed. + + `timestamp` (str): A UNIX timestamp of when the result + was processed. This may be overridden. + + `hostname` (str): Hostname of the machine hosting the training + process. + + `node_ip` (str): Node ip of the machine hosting the training + process. + + Returns: + A dict that describes training progress. 
+ """ + start = time.time() + try: + result = self.step() + except Exception as e: + skipped = skip_exceptions(e) + raise skipped from exception_cause(skipped) + + assert isinstance(result, dict), "step() needs to return a dict." + + # We do not modify internal state nor update this result if duplicate. + if RESULT_DUPLICATE in result: + return result + + result = result.copy() + + self._iteration += 1 + self._iterations_since_restore += 1 + + if result.get(TIME_THIS_ITER_S) is not None: + time_this_iter = result[TIME_THIS_ITER_S] + else: + time_this_iter = time.time() - start + self._time_total += time_this_iter + self._time_since_restore += time_this_iter + + result_timestamp = result.get(TIMESTAMP, None) + + result.setdefault(DONE, False) + + # self._timesteps_total should only be tracked if increments are provided + if result.get(TIMESTEPS_THIS_ITER) is not None: + if self._timesteps_total is None: + self._timesteps_total = 0 + self._timesteps_total += result[TIMESTEPS_THIS_ITER] + self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER] + + # self._episodes_total should only be tracked if increments provided + if result.get(EPISODES_THIS_ITER) is not None: + if self._episodes_total is None: + self._episodes_total = 0 + self._episodes_total += result[EPISODES_THIS_ITER] + + # self._timesteps_total should not override user-provided total + if self._timesteps_total is not None: + result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total) + if self._episodes_total is not None: + result.setdefault(EPISODES_TOTAL, self._episodes_total) + result.setdefault(TRAINING_ITERATION, self._iteration) + + now = datetime.today() + result.update( + self.get_auto_filled_metrics( + now=now, time_this_iter=time_this_iter, timestamp=result_timestamp + ) + ) + + monitor_data = self._monitor.get_data() + if monitor_data: + result.update(monitor_data) + + self.log_result(result) + + if self._stdout_context: + self._stdout_stream.flush() + if self._stderr_context: + self._stderr_stream.flush() + + self._last_result = result + + if self._storage: + # Launch background tasks to sync artifacts at some specified frequency. + self._storage.persist_artifacts() + + return result + + def get_state(self): + return { + "iteration": self._iteration, + "timesteps_total": self._timesteps_total, + "time_total": self._time_total, + "episodes_total": self._episodes_total, + "last_result": self._last_result, + "ray_version": ray.__version__, + } + + def _report_class_trainable_checkpoint( + self, checkpoint_dir: str, checkpoint_dict_or_path: Union[str, Dict] + ) -> _TrainingResult: + """Report a checkpoint saved via Trainable.save_checkpoint. + + Need to handle both dict or path checkpoint returned by the user's + `save_checkpoint` method. + + This is to get class trainables to work with storage backend used by + function trainables. + This basically re-implements `train.report` for class trainables, + making sure to persist the checkpoint to storage. + """ + if isinstance(checkpoint_dict_or_path, dict): + with Path(checkpoint_dir, _DICT_CHECKPOINT_FILE_NAME).open("wb") as f: + ray_pickle.dump(checkpoint_dict_or_path, f) + elif isinstance(checkpoint_dict_or_path, str): + if checkpoint_dict_or_path != checkpoint_dir: + raise ValueError( + "The returned checkpoint path from `save_checkpoint` " + "must be None or the same as the provided path argument." 
+ f"Got {checkpoint_dict_or_path} != {checkpoint_dir}" + ) + + local_checkpoint = Checkpoint.from_directory(checkpoint_dir) + + metrics = self._last_result.copy() if self._last_result else {} + + if self._storage: + # The checkpoint index is updated with the current result. + # NOTE: This is no longer using "iteration" as the folder indexing + # to be consistent with fn trainables. + self._storage._update_checkpoint_index(metrics) + + persisted_checkpoint = self._storage.persist_current_checkpoint( + local_checkpoint + ) + + checkpoint_result = _TrainingResult( + checkpoint=persisted_checkpoint, metrics=metrics + ) + # Persist trial artifacts to storage. + self._storage.persist_artifacts( + force=self._storage.sync_config.sync_artifacts_on_checkpoint + ) + else: + # `storage=None` only happens when initializing the + # Trainable manually, outside of Tune/Train. + # In this case, no storage is set, so the default behavior + # is to just not upload anything and report a local checkpoint. + # This is fine for the main use case of local debugging. + checkpoint_result = _TrainingResult( + checkpoint=local_checkpoint, metrics=metrics + ) + return checkpoint_result + + @DeveloperAPI + def save(self, checkpoint_dir: Optional[str] = None) -> _TrainingResult: + """Saves the current model state to a checkpoint. + + Subclasses should override ``save_checkpoint()`` instead to save state. + + Args: + checkpoint_dir: Optional dir to place the checkpoint. + + Returns: + The given or created checkpoint directory. + + Note the return value matches up with what is expected of `restore()`. + """ + if not isinstance(self, ray.tune.trainable.FunctionTrainable): + # Use a temporary directory if no checkpoint_dir is provided. + use_temp_dir = not checkpoint_dir + checkpoint_dir = checkpoint_dir or tempfile.mkdtemp() + os.makedirs(checkpoint_dir, exist_ok=True) + + checkpoint_dict_or_path = self.save_checkpoint(checkpoint_dir) + checkpoint_result = self._report_class_trainable_checkpoint( + checkpoint_dir, checkpoint_dict_or_path + ) + + # Clean up the temporary directory, since it's already been + # reported + persisted to storage. If no storage is set, the user is + # running the Trainable locally and is responsible for cleaning + # up the checkpoint directory themselves. + if use_temp_dir and self._storage: + shutil.rmtree(checkpoint_dir, ignore_errors=True) + else: + checkpoint_result: _TrainingResult = self.save_checkpoint(None) + assert isinstance(checkpoint_result, _TrainingResult) + assert self._last_result + # Update the checkpoint result to include auto-filled metrics. + checkpoint_result.metrics.update(self._last_result) + + return checkpoint_result + + @DeveloperAPI + def restore(self, checkpoint_path: Union[str, Checkpoint, _TrainingResult]): + """Restores training state from a given model checkpoint. + + These checkpoints are returned from calls to save(). + + Subclasses should override ``load_checkpoint()`` instead to + restore state. + This method restores additional metadata saved with the checkpoint. + + `checkpoint_path` should match with the return from ``save()``. + + Args: + checkpoint_path: Path to restore checkpoint from. If this + path does not exist on the local node, it will be fetched + from external (cloud) storage if available, or restored + from a remote node. + checkpoint_node_ip: If given, try to restore + checkpoint from this node if it doesn't exist locally or + on cloud storage. 
+ fallback_to_latest: If True, will try to recover the + latest available checkpoint if the given ``checkpoint_path`` + could not be found. + + """ + # TODO(justinvyu): Clean up this interface + if isinstance(checkpoint_path, str): + checkpoint_path = Checkpoint.from_directory(checkpoint_path) + if isinstance(checkpoint_path, Checkpoint): + checkpoint_result = _TrainingResult(checkpoint=checkpoint_path, metrics={}) + else: + checkpoint_result: _TrainingResult = checkpoint_path + + assert isinstance(checkpoint_result, _TrainingResult), type(checkpoint_result) + checkpoint = checkpoint_result.checkpoint + checkpoint_metrics = checkpoint_result.metrics + self._iteration = checkpoint_metrics.get(TRAINING_ITERATION, 0) + self._time_total = checkpoint_metrics.get(TIME_TOTAL_S, 0) + self._time_since_restore = 0.0 + self._iterations_since_restore = 0 + + # TODO(justinvyu): This stuff should be moved to rllib. + self._timesteps_total = checkpoint_metrics.get(TIMESTEPS_TOTAL) + self._timesteps_since_restore = 0 + self._episodes_total = checkpoint_metrics.get(EPISODES_TOTAL) + + if not _exists_at_fs_path(checkpoint.filesystem, checkpoint.path): + raise ValueError( + f"Could not recover from checkpoint as it does not exist on " + f"storage anymore. " + f"Got storage fs type `{checkpoint.filesystem.type_name}` and " + f"path: {checkpoint.path}" + ) + + # TODO(justinvyu): [cls_trainable_support] + # This is to conform to the public class Trainable `load_checkpoint` API. + if not isinstance(self, ray.tune.trainable.FunctionTrainable): + # Need to convert Checkpoint -> local path or dict + # (depending on what the output of save_checkpoint was) + with checkpoint.as_directory() as checkpoint_dir: + checkpoint_path = Path(checkpoint_dir) + dict_checkpoint_file = checkpoint_path / _DICT_CHECKPOINT_FILE_NAME + if dict_checkpoint_file.exists(): + # If this was a dict checkpoint, load it as a dict + with open(dict_checkpoint_file, "rb") as f: + checkpoint_dict = ray_pickle.load(f) + self.load_checkpoint(checkpoint_dict) + else: + self.load_checkpoint(checkpoint_dir) + else: + # TODO(justinvyu): The Function Trainable case doesn't conform + # to the load_checkpoint API at the moment. + self.load_checkpoint(checkpoint_result) + + self._restored = True + + logger.info(f"Restored on {self._local_ip} from checkpoint: {checkpoint}") + + def export_model( + self, export_formats: Union[List[str], str], export_dir: Optional[str] = None + ): + """Exports model based on export_formats. + + Subclasses should override _export_model() to actually + export model to local directory. + + Args: + export_formats: Format or list of (str) formats + that should be exported. + export_dir: Optional dir to place the exported model. + Defaults to self.logdir. + + Returns: + A dict that maps ExportFormats to successfully exported models. + """ + if isinstance(export_formats, str): + export_formats = [export_formats] + export_dir = export_dir or self.logdir + return self._export_model(export_formats, export_dir) + + def reset(self, new_config, logger_creator=None, storage=None): + """Resets trial for use with new config. 
+ + Subclasses should override reset_config() to actually + reset actor behavior for the new config.""" + self.config = new_config + + self._storage = storage + + trial_info = new_config.pop(TRIAL_INFO, None) + if trial_info: + self._trial_info = trial_info + + self._result_logger.flush() + self._result_logger.close() + + if logger_creator: + logger.debug("Logger reset.") + self._create_logger(new_config.copy(), logger_creator) + else: + logger.debug( + "Did not reset logger. Got: " + f"trainable.reset(logger_creator={logger_creator})." + ) + + stdout_file = new_config.pop(STDOUT_FILE, None) + stderr_file = new_config.pop(STDERR_FILE, None) + + self._close_logfiles() + self._open_logfiles(stdout_file, stderr_file) + + success = self.reset_config(new_config) + if not success: + return False + + # Reset attributes. Will be overwritten by `restore` if a checkpoint + # is provided. + self._iteration = 0 + self._time_total = 0.0 + self._timesteps_total = None + self._episodes_total = None + self._time_since_restore = 0.0 + self._timesteps_since_restore = 0 + self._iterations_since_restore = 0 + self._restored = False + + return True + + def reset_config(self, new_config: Dict): + """Resets configuration without restarting the trial. + + This method is optional, but can be implemented to speed up algorithms + such as PBT, and to allow performance optimizations such as running + experiments with reuse_actors=True. + + Args: + new_config: Updated hyperparameter configuration + for the trainable. + + Returns: + True if reset was successful else False. + """ + return False + + def _create_logger( + self, + config: Dict[str, Any], + logger_creator: Callable[[Dict[str, Any]], "Logger"] = None, + ): + """Create logger from logger creator. + + Sets _logdir and _result_logger. + + `_logdir` is the **per trial** directory for the Trainable. + """ + if logger_creator: + self._result_logger = logger_creator(config) + self._logdir = self._result_logger.logdir + else: + from ray.tune.logger import UnifiedLogger + + logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") + ray._private.utils.try_to_create_directory(DEFAULT_STORAGE_PATH) + self._logdir = tempfile.mkdtemp( + prefix=logdir_prefix, dir=DEFAULT_STORAGE_PATH + ) + self._result_logger = UnifiedLogger(config, self._logdir, loggers=None) + + def _open_logfiles(self, stdout_file, stderr_file): + """Create loggers. 
Open stdout and stderr logfiles.""" + if stdout_file: + stdout_path = (Path(self._logdir) / stdout_file).expanduser().as_posix() + self._stdout_fp = open(stdout_path, "a+") + self._stdout_stream = Tee(sys.stdout, self._stdout_fp) + self._stdout_context = redirect_stdout(self._stdout_stream) + self._stdout_context.__enter__() + + if stderr_file: + stderr_path = (Path(self._logdir) / stderr_file).expanduser().as_posix() + self._stderr_fp = open(stderr_path, "a+") + self._stderr_stream = Tee(sys.stderr, self._stderr_fp) + self._stderr_context = redirect_stderr(self._stderr_stream) + self._stderr_context.__enter__() + + # Add logging handler to root ray logger + formatter = logging.Formatter( + "[%(levelname)s %(asctime)s] " + "%(filename)s: %(lineno)d " + "%(message)s" + ) + self._stderr_logging_handler = logging.StreamHandler(self._stderr_fp) + self._stderr_logging_handler.setFormatter(formatter) + ray.logger.addHandler(self._stderr_logging_handler) + + def _close_logfiles(self): + """Close stdout and stderr logfiles.""" + if self._stderr_logging_handler: + ray.logger.removeHandler(self._stderr_logging_handler) + + if self._stdout_context: + self._stdout_stream.flush() + self._stdout_context.__exit__(None, None, None) + self._stdout_fp.close() + self._stdout_context = None + if self._stderr_context: + self._stderr_stream.flush() + self._stderr_context.__exit__(None, None, None) + self._stderr_fp.close() + self._stderr_context = None + + def stop(self): + """Releases all resources used by this trainable. + + Calls ``Trainable.cleanup`` internally. Subclasses should override + ``Trainable.cleanup`` for custom cleanup procedures. + """ + self._result_logger.flush() + self._result_logger.close() + if self._monitor.is_alive(): + self._monitor.stop() + self._monitor.join() + self.cleanup() + + self._close_logfiles() + + @property + def logdir(self): + """Directory of the results and checkpoints for this Trainable. + + Note that the current working directory will also be changed to this. + """ + return self._logdir + + @property + def trial_name(self): + """Trial name for the corresponding trial of this Trainable. + + This is not set if not using Tune. + + .. testcode:: + + from ray.tune import Trainable + + name = Trainable().trial_name + """ + if self._trial_info: + return self._trial_info.trial_name + else: + return "default" + + @property + def trial_id(self): + """Trial ID for the corresponding trial of this Trainable. + + This is not set if not using Tune. + + .. testcode:: + + from ray.tune import Trainable + + trial_id = Trainable().trial_id + """ + if self._trial_info: + return self._trial_info.trial_id + else: + return "default" + + @property + def trial_resources(self) -> Optional[PlacementGroupFactory]: + """Resources currently assigned to the trial of this Trainable. + + This is not set if not using Tune. + + .. testcode:: + + from ray.tune import Trainable + + trial_resources = Trainable().trial_resources + """ + if self._trial_info: + return self._trial_info.trial_resources + else: + return None + + @property + def iteration(self): + """Current training iteration. + + This value is automatically incremented every time `train()` is called + and is automatically inserted into the training result dict. + + """ + return self._iteration + + @property + def training_iteration(self): + """Current training iteration (same as `self.iteration`). + + This value is automatically incremented every time `train()` is called + and is automatically inserted into the training result dict. 
+ + """ + return self._iteration + + def get_config(self): + """Returns configuration passed in by Tune.""" + return self.config + + def step(self): + """Subclasses should override this to implement train(). + + The return value will be automatically passed to the loggers. Users + can also return `tune.result.DONE` or `tune.result.SHOULD_CHECKPOINT` + as a key to manually trigger termination or checkpointing of this + trial. Note that manual checkpointing only works when subclassing + Trainables. + + .. versionadded:: 0.8.7 + + Returns: + A dict that describes training progress. + + """ + raise NotImplementedError + + def save_checkpoint(self, checkpoint_dir: str) -> Optional[Dict]: + """Subclasses should override this to implement ``save()``. + + Warning: + Do not rely on absolute paths in the implementation of + ``Trainable.save_checkpoint`` and ``Trainable.load_checkpoint``. + + Use ``validate_save_restore`` to catch ``Trainable.save_checkpoint``/ + ``Trainable.load_checkpoint`` errors before execution. + + >>> from ray.tune.utils import validate_save_restore + >>> MyTrainableClass = ... # doctest: +SKIP + >>> validate_save_restore(MyTrainableClass) # doctest: +SKIP + + .. versionadded:: 0.8.7 + + Args: + checkpoint_dir: The directory where the checkpoint + file must be stored. In a Tune run, if the trial is paused, + the provided path may be temporary and moved. + + Returns: + A dict or None. If dict, the return value will + be automatically serialized by Tune. In that case, + ``Trainable.load_checkpoint()`` will receive the dict upon restore. + + Example: + >>> trainable, trainable1, trainable2 = ... # doctest: +SKIP + >>> print(trainable1.save_checkpoint("/tmp/checkpoint_1")) # doctest: +SKIP + "/tmp/checkpoint_1" + >>> print(trainable2.save_checkpoint("/tmp/checkpoint_2")) # doctest: +SKIP + {"some": "data"} + >>> trainable.save_checkpoint("/tmp/bad_example") # doctest: +SKIP + "/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error. + """ + raise NotImplementedError + + def load_checkpoint(self, checkpoint: Optional[Dict]): + """Subclasses should override this to implement restore(). + + Warning: + In this method, do not rely on absolute paths. The absolute + path of the checkpoint_dir used in ``Trainable.save_checkpoint`` + may be changed. + + If ``Trainable.save_checkpoint`` returned a prefixed string, the + prefix of the checkpoint string returned by + ``Trainable.save_checkpoint`` may be changed. + This is because trial pausing depends on temporary directories. + + The directory structure under the checkpoint_dir provided to + ``Trainable.save_checkpoint`` is preserved. + + See the examples below. + + Example: + >>> import os + >>> from ray.tune.trainable import Trainable + >>> class Example(Trainable): + ... def save_checkpoint(self, checkpoint_path): + ... my_checkpoint_path = os.path.join(checkpoint_path, "my/path") + ... return my_checkpoint_path + ... def load_checkpoint(self, my_checkpoint_path): + ... print(my_checkpoint_path) + >>> trainer = Example() + >>> # This is used when PAUSED. + >>> checkpoint_result = trainer.save() # doctest: +SKIP + >>> trainer.restore(checkpoint_result) # doctest: +SKIP + + If `Trainable.save_checkpoint` returned a dict, then Tune will directly pass + the dict data as the argument to this method. + + Example: + >>> from ray.tune.trainable import Trainable + >>> class Example(Trainable): + ... def save_checkpoint(self, checkpoint_path): + ... return {"my_data": 1} + ... def load_checkpoint(self, checkpoint_dict): + ... 
print(checkpoint_dict["my_data"]) + + .. versionadded:: 0.8.7 + + Args: + checkpoint: If dict, the return value is as + returned by ``save_checkpoint``. Otherwise, the directory + the checkpoint was stored in. + """ + raise NotImplementedError + + def setup(self, config: Dict): + """Subclasses should override this for custom initialization. + + .. versionadded:: 0.8.7 + + Args: + config: Hyperparameters and other configs given. + Copy of `self.config`. + + """ + pass + + def log_result(self, result: Dict): + """Subclasses can optionally override this to customize logging. + + The logging here is done on the worker process rather than + the driver. + + .. versionadded:: 0.8.7 + + Args: + result: Training result returned by step(). + """ + self._result_logger.on_result(result) + + def cleanup(self): + """Subclasses should override this for any cleanup on stop. + + If any Ray actors are launched in the Trainable (i.e., with a RLlib + trainer), be sure to kill the Ray actor process here. + + This process should be lightweight. Per default, + + You can kill a Ray actor by calling `ray.kill(actor)` + on the actor or removing all references to it and waiting for garbage + collection + + .. versionadded:: 0.8.7 + """ + pass + + def _export_model(self, export_formats: List[str], export_dir: str): + """Subclasses should override this to export model. + + Args: + export_formats: List of formats that should be exported. + export_dir: Directory to place exported models. + + Return: + A dict that maps ExportFormats to successfully exported models. + """ + return {} + + def _implements_method(self, key): + return hasattr(self, key) and callable(getattr(self, key)) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/trainable/trainable_fn_utils.py b/.venv/lib/python3.11/site-packages/ray/tune/trainable/trainable_fn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2fcb1ac529aaaaeba8051ee4424bf29c61f37532 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/trainable/trainable_fn_utils.py @@ -0,0 +1,56 @@ +from typing import Dict, Optional + +from ray.train._checkpoint import Checkpoint as TrainCheckpoint +from ray.train._internal.session import _warn_session_misuse, get_session +from ray.train.constants import _v2_migration_warnings_enabled +from ray.train.utils import _copy_doc, _log_deprecation_warning +from ray.util.annotations import PublicAPI + + +@_copy_doc(TrainCheckpoint) +class Checkpoint(TrainCheckpoint): + # NOTE: This is just a pass-through wrapper around `ray.train.Checkpoint` + # in order to detect whether the import module was correct `ray.tune.Checkpoint`. + pass + + +@PublicAPI(stability="stable") +@_warn_session_misuse() +def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None: + """Report metrics and optionally save and register a checkpoint to Ray Tune. + + If a checkpoint is provided, it will be + :ref:`persisted to storage `. + + .. note:: + + Each invocation of this method will automatically increment the underlying + ``training_iteration`` number. The physical meaning of this "iteration" is + defined by user depending on how often they call ``report``. + It does not necessarily map to one epoch. + + Args: + metrics: The metrics you want to report. + checkpoint: The optional checkpoint you want to report. 
+ """ + if checkpoint and not isinstance(checkpoint, Checkpoint): + if _v2_migration_warnings_enabled(): + _log_deprecation_warning( + "The `Checkpoint` class should be imported from `ray.tune` " + "when passing it to `ray.tune.report` in a Tune function." + "Please update your imports." + ) + + get_session().report(metrics, checkpoint=checkpoint) + + +@PublicAPI(stability="stable") +@_warn_session_misuse() +def get_checkpoint() -> Optional[Checkpoint]: + """Access the latest reported checkpoint to resume from if one exists.""" + + return get_session().loaded_checkpoint + + +def _in_tune_session() -> bool: + return get_session() and get_session().world_rank is None diff --git a/.venv/lib/python3.11/site-packages/ray/tune/trainable/util.py b/.venv/lib/python3.11/site-packages/ray/tune/trainable/util.py new file mode 100644 index 0000000000000000000000000000000000000000..5c637fd6bc00bda456661bed90edf9353588f1c2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/trainable/util.py @@ -0,0 +1,243 @@ +import inspect +import logging +import types +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union + +import ray +from ray.tune.execution.placement_groups import ( + PlacementGroupFactory, + resource_dict_to_pg_factory, +) +from ray.tune.registry import _ParameterRegistry +from ray.util.annotations import PublicAPI + +if TYPE_CHECKING: + from ray.tune.trainable import Trainable + +logger = logging.getLogger(__name__) + + +@PublicAPI(stability="beta") +def with_parameters(trainable: Union[Type["Trainable"], Callable], **kwargs): + """Wrapper for trainables to pass arbitrary large data objects. + + This wrapper function will store all passed parameters in the Ray + object store and retrieve them when calling the function. It can thus + be used to pass arbitrary data, even datasets, to Tune trainables. + + This can also be used as an alternative to ``functools.partial`` to pass + default arguments to trainables. + + When used with the function API, the trainable function is called with + the passed parameters as keyword arguments. When used with the class API, + the ``Trainable.setup()`` method is called with the respective kwargs. + + If the data already exists in the object store (are instances of + ObjectRef), using ``tune.with_parameters()`` is not necessary. You can + instead pass the object refs to the training function via the ``config`` + or use Python partials. + + Args: + trainable: Trainable to wrap. + **kwargs: parameters to store in object store. + + Function API example: + + .. code-block:: python + + from ray import train, tune + + def train_fn(config, data=None): + for sample in data: + loss = update_model(sample) + train.report(loss=loss) + + data = HugeDataset(download=True) + + tuner = Tuner( + tune.with_parameters(train_fn, data=data), + # ... + ) + tuner.fit() + + Class API example: + + .. code-block:: python + + from ray import tune + + class MyTrainable(tune.Trainable): + def setup(self, config, data=None): + self.data = data + self.iter = iter(self.data) + self.next_sample = next(self.iter) + + def step(self): + loss = update_model(self.next_sample) + try: + self.next_sample = next(self.iter) + except StopIteration: + return {"loss": loss, done: True} + return {"loss": loss} + + data = HugeDataset(download=True) + + tuner = Tuner( + tune.with_parameters(MyTrainable, data=data), + # ... 
+ ) + """ + from ray.tune.trainable import Trainable + + if not callable(trainable) or ( + inspect.isclass(trainable) and not issubclass(trainable, Trainable) + ): + raise ValueError( + f"`tune.with_parameters() only works with function trainables " + f"or classes that inherit from `tune.Trainable()`. Got type: " + f"{type(trainable)}." + ) + + parameter_registry = _ParameterRegistry() + ray._private.worker._post_init_hooks.append(parameter_registry.flush) + + # Objects are moved into the object store + prefix = f"{str(trainable)}_" + for k, v in kwargs.items(): + parameter_registry.put(prefix + k, v) + + trainable_name = getattr(trainable, "__name__", "tune_with_parameters") + keys = set(kwargs.keys()) + + if inspect.isclass(trainable): + # Class trainable + + class _Inner(trainable): + def setup(self, config): + setup_kwargs = {} + for k in keys: + setup_kwargs[k] = parameter_registry.get(prefix + k) + super(_Inner, self).setup(config, **setup_kwargs) + + trainable_with_params = _Inner + else: + # Function trainable + + def inner(config): + fn_kwargs = {} + for k in keys: + fn_kwargs[k] = parameter_registry.get(prefix + k) + return trainable(config, **fn_kwargs) + + trainable_with_params = inner + + if hasattr(trainable, "__mixins__"): + trainable_with_params.__mixins__ = trainable.__mixins__ + + # If the trainable has been wrapped with `tune.with_resources`, we should + # keep the `_resources` attribute around + if hasattr(trainable, "_resources"): + trainable_with_params._resources = trainable._resources + + trainable_with_params.__name__ = trainable_name + return trainable_with_params + + +@PublicAPI(stability="beta") +def with_resources( + trainable: Union[Type["Trainable"], Callable], + resources: Union[ + Dict[str, float], + PlacementGroupFactory, + Callable[[dict], PlacementGroupFactory], + ], +): + """Wrapper for trainables to specify resource requests. + + This wrapper allows specification of resource requirements for a specific + trainable. It will override potential existing resource requests (use + with caution!). + + The main use case is to request resources for function trainables when used + with the Tuner() API. + + Class trainables should usually just implement the ``default_resource_request()`` + method. + + Args: + trainable: Trainable to wrap. + resources: Resource dict, placement group factory, or callable that takes + in a config dict and returns a placement group factory. + + Example: + + .. code-block:: python + + from ray import tune + from ray.tune.tuner import Tuner + + def train_fn(config): + return len(ray.get_gpu_ids()) # Returns 2 + + tuner = Tuner( + tune.with_resources(train_fn, resources={"gpu": 2}), + # ... + ) + results = tuner.fit() + + """ + from ray.tune.trainable import Trainable + + if not callable(trainable) or ( + inspect.isclass(trainable) and not issubclass(trainable, Trainable) + ): + raise ValueError( + f"`tune.with_resources() only works with function trainables " + f"or classes that inherit from `tune.Trainable()`. Got type: " + f"{type(trainable)}." 
+ ) + + if isinstance(resources, PlacementGroupFactory): + pgf = resources + elif isinstance(resources, dict): + pgf = resource_dict_to_pg_factory(resources) + elif callable(resources): + pgf = resources + else: + raise ValueError( + f"Invalid resource type for `with_resources()`: {type(resources)}" + ) + + if not inspect.isclass(trainable): + if isinstance(trainable, types.MethodType): + # Methods cannot set arbitrary attributes, so we have to wrap them + def _trainable(config): + return trainable(config) + + _trainable._resources = pgf + return _trainable + + # Just set an attribute. This will be resolved later in `wrap_function()`. + try: + trainable._resources = pgf + except AttributeError as e: + raise RuntimeError( + "Could not use `tune.with_resources()` on the supplied trainable. " + "Wrap your trainable in a regular function before passing it " + "to Ray Tune." + ) from e + else: + + class ResourceTrainable(trainable): + @classmethod + def default_resource_request( + cls, config: Dict[str, Any] + ) -> Optional[PlacementGroupFactory]: + if not isinstance(pgf, PlacementGroupFactory) and callable(pgf): + return pgf(config) + return pgf + + ResourceTrainable.__name__ = trainable.__name__ + trainable = ResourceTrainable + + return trainable diff --git a/.venv/lib/python3.11/site-packages/ray/tune/tune.py b/.venv/lib/python3.11/site-packages/ray/tune/tune.py new file mode 100644 index 0000000000000000000000000000000000000000..879bf53f7d58d5fee36ee516e8a5bf7be6c31d28 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/tune.py @@ -0,0 +1,1161 @@ +import abc +import copy +import datetime +import logging +import os +import signal +import sys +import threading +import time +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Mapping, + Optional, + Sequence, + Type, + Union, +) + +import ray +from ray.air._internal import usage as air_usage +from ray.air._internal.usage import AirEntrypoint +from ray.air.util.node import _force_on_current_node +from ray.train import CheckpointConfig, SyncConfig +from ray.train.constants import _DEPRECATED_VALUE, RAY_CHDIR_TO_TRIAL_DIR +from ray.tune.analysis import ExperimentAnalysis +from ray.tune.callback import Callback +from ray.tune.error import TuneError +from ray.tune.execution.placement_groups import PlacementGroupFactory +from ray.tune.execution.tune_controller import TuneController +from ray.tune.experiment import Experiment, Trial, _convert_to_experiment_list +from ray.tune.experimental.output import IS_NOTEBOOK, AirVerbosity, get_air_verbosity +from ray.tune.impl.placeholder import create_resolvers_map, inject_placeholders +from ray.tune.logger import TBXLoggerCallback +from ray.tune.progress_reporter import ( + ProgressReporter, + _detect_progress_metrics, + _detect_reporter, + _prepare_progress_reporter_for_ray_client, + _stream_client_output, +) +from ray.tune.registry import get_trainable_cls + +# Must come last to avoid circular imports +from ray.tune.schedulers import ( + FIFOScheduler, + PopulationBasedTraining, + PopulationBasedTrainingReplay, + TrialScheduler, +) +from ray.tune.schedulers.util import ( + _set_search_properties_backwards_compatible as scheduler_set_search_props, +) +from ray.tune.search import ( + BasicVariantGenerator, + ConcurrencyLimiter, + SearchAlgorithm, + Searcher, + SearchGenerator, + create_searcher, +) +from ray.tune.search.util import ( + _set_search_properties_backwards_compatible as searcher_set_search_props, +) +from ray.tune.search.variant_generator import 
_has_unresolved_values +from ray.tune.stopper import Stopper +from ray.tune.trainable import Trainable +from ray.tune.tune_config import ResumeConfig +from ray.tune.utils.callback import _create_default_callbacks +from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity +from ray.util.annotations import PublicAPI +from ray.util.queue import Queue + +if TYPE_CHECKING: + import pyarrow.fs + + from ray.tune.experimental.output import ProgressReporter as AirProgressReporter + +logger = logging.getLogger(__name__) + + +def _get_trainable( + run_identifier: Union[Experiment, str, Type, Callable] +) -> Optional[Type[Trainable]]: + if isinstance(run_identifier, Experiment): + run_identifier = run_identifier.run_identifier + + if isinstance(run_identifier, type): + if not issubclass(run_identifier, Trainable): + # If obscure dtype, assume it is overridden. + return None + trainable_cls = run_identifier + elif callable(run_identifier): + trainable_cls = run_identifier + elif isinstance(run_identifier, str): + trainable_cls = get_trainable_cls(run_identifier) + else: + return None + + return trainable_cls + + +def _build_resume_config_from_legacy_config( + resume: Union[str, bool] +) -> Optional[ResumeConfig]: + """Converts the legacy resume (str, bool) to a ResumeConfig object. + Returns None if resume is False. + """ + if resume is False: + return None + if resume is True: + return ResumeConfig() + + # Parse resume string, e.g. AUTO+ERRORED + resume_settings = resume.split("+") + resume_str = resume_settings[0] + + if resume_str in ("LOCAL", "REMOTE", "PROMPT", "ERRORED_ONLY"): + raise DeprecationWarning( + f"'{resume_str}' is deprecated. " + "Please pass in one of (True, False, 'AUTO')." + ) + + resume_config = ResumeConfig() + for setting in resume_settings[1:]: + if setting == "ERRORED": + resume_config = ResumeConfig(errored=ResumeConfig.ResumeType.RESUME) + elif setting == "RESTART_ERRORED": + resume_config = ResumeConfig(errored=ResumeConfig.ResumeType.RESTART) + elif setting == "ERRORED_ONLY": + resume_config = ResumeConfig( + unfinished=ResumeConfig.ResumeType.SKIP, + errored=ResumeConfig.ResumeType.RESUME, + ) + elif setting == "RESTART_ERRORED_ONLY": + resume_config = ResumeConfig( + unfinished=ResumeConfig.ResumeType.SKIP, + errored=ResumeConfig.ResumeType.RESTART, + ) + else: + raise ValueError(f"Invalid resume setting: '{setting}'") + + return resume_config + + +def _check_default_resources_override( + run_identifier: Union[Experiment, str, Type, Callable] +) -> bool: + trainable_cls = _get_trainable(run_identifier) + if not trainable_cls: + # If no trainable, assume override + return True + + return hasattr(trainable_cls, "default_resource_request") and ( + trainable_cls.default_resource_request.__code__ + != Trainable.default_resource_request.__code__ + ) + + +def _check_mixin(run_identifier: Union[Experiment, str, Type, Callable]) -> bool: + trainable_cls = _get_trainable(run_identifier) + if not trainable_cls: + # Default to True + return True + + return hasattr(trainable_cls, "__mixins__") or getattr( + trainable_cls, "_is_mixin", False + ) + + +def _check_gpus_in_resources( + resources: Optional[Union[Dict, PlacementGroupFactory]] +) -> bool: + if not resources: + return False + + if isinstance(resources, PlacementGroupFactory): + return bool(resources.required_resources.get("GPU", None)) + + if isinstance(resources, dict): + return bool(resources.get("gpu", None)) + + +def _report_progress( + runner: TuneController, reporter: ProgressReporter, done: bool = False 
+): + """Reports experiment progress. + + Args: + runner: Trial runner to report on. + reporter: Progress reporter. + done: Whether this is the last progress report attempt. + """ + trials = runner.get_trials() + if reporter.should_report(trials, done=done): + sched_debug_str = runner.scheduler_alg.debug_string() + used_resources_str = runner._used_resources_string() + reporter.report(trials, done, sched_debug_str, used_resources_str) + + +def _report_air_progress( + runner: TuneController, reporter: "AirProgressReporter", force: bool = False +): + trials = runner.get_trials() + reporter_args = [] + used_resources_string = runner._used_resources_string() + reporter_args.append(used_resources_string) + reporter.print_heartbeat(trials, *reporter_args, force=force) + + +def _setup_signal_catching() -> threading.Event: + original_handler = signal.getsignal(signal.SIGINT) + experiment_interrupted_event = threading.Event() + + def signal_interrupt_tune_run(sig: int, frame): + logger.warning( + "Stop signal received (e.g. via SIGINT/Ctrl+C), ending Ray Tune run. " + "This will try to checkpoint the experiment state one last time. " + "Press CTRL+C (or send SIGINT/SIGKILL/SIGTERM) " + "to skip. " + ) + experiment_interrupted_event.set() + # Restore original signal handler to react to future SIGINT signals. + signal.signal(signal.SIGINT, original_handler) + + # We should only install the handler when it is safe to do so. + # When tune.run() is called from worker thread, signal.signal will + # fail. + allow_signal_catching = True + if threading.current_thread() != threading.main_thread(): + allow_signal_catching = False + + if allow_signal_catching: + if not int(os.getenv("TUNE_DISABLE_SIGINT_HANDLER", "0")): + signal.signal(signal.SIGINT, signal_interrupt_tune_run) + + # Always register SIGUSR1 if available (not available e.g. on Windows) + if hasattr(signal, "SIGUSR1"): + signal.signal(signal.SIGUSR1, signal_interrupt_tune_run) + + return experiment_interrupted_event + + +def _ray_auto_init(entrypoint: str): + """Initialize Ray unless already configured.""" + if os.environ.get("TUNE_DISABLE_AUTO_INIT") == "1": + logger.info("'TUNE_DISABLE_AUTO_INIT=1' detected.") + elif not ray.is_initialized(): + ray.init() + logger.info( + "Initializing Ray automatically. " + "For cluster usage or custom Ray initialization, " + f"call `ray.init(...)` before `{entrypoint}`." 
+ ) + + +class _Config(abc.ABC): + def to_dict(self) -> dict: + """Converts this configuration to a dict format.""" + raise NotImplementedError + + +@PublicAPI +def run( + run_or_experiment: Union[str, Callable, Type], + *, + name: Optional[str] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + stop: Optional[Union[Mapping, Stopper, Callable[[str, Mapping], bool]]] = None, + time_budget_s: Optional[Union[int, float, datetime.timedelta]] = None, + config: Optional[Dict[str, Any]] = None, + resources_per_trial: Union[ + None, Mapping[str, Union[float, int, Mapping]], PlacementGroupFactory + ] = None, + num_samples: int = 1, + storage_path: Optional[str] = None, + storage_filesystem: Optional["pyarrow.fs.FileSystem"] = None, + search_alg: Optional[Union[Searcher, SearchAlgorithm, str]] = None, + scheduler: Optional[Union[TrialScheduler, str]] = None, + checkpoint_config: Optional[CheckpointConfig] = None, + verbose: Optional[Union[int, AirVerbosity, Verbosity]] = None, + progress_reporter: Optional[ProgressReporter] = None, + log_to_file: bool = False, + trial_name_creator: Optional[Callable[[Trial], str]] = None, + trial_dirname_creator: Optional[Callable[[Trial], str]] = None, + sync_config: Optional[SyncConfig] = None, + export_formats: Optional[Sequence] = None, + max_failures: int = 0, + fail_fast: bool = False, + restore: Optional[str] = None, + resume: Optional[Union[bool, str]] = None, + resume_config: Optional[ResumeConfig] = None, + reuse_actors: bool = False, + raise_on_failed_trial: bool = True, + callbacks: Optional[Sequence[Callback]] = None, + max_concurrent_trials: Optional[int] = None, + # Deprecated + keep_checkpoints_num: Optional[int] = None, # Deprecated (2.7) + checkpoint_score_attr: Optional[str] = None, # Deprecated (2.7) + checkpoint_freq: int = 0, # Deprecated (2.7) + checkpoint_at_end: bool = False, # Deprecated (2.7) + chdir_to_trial_dir: bool = _DEPRECATED_VALUE, # Deprecated (2.8) + local_dir: Optional[str] = None, + # == internal only == + _remote: Optional[bool] = None, + # Passed by the Tuner. + _remote_string_queue: Optional[Queue] = None, + # Todo (krfricke): Find a better way to pass entrypoint information, e.g. + # a context object or similar. + _entrypoint: AirEntrypoint = AirEntrypoint.TUNE_RUN, +) -> ExperimentAnalysis: + """Executes training. + + When a SIGINT signal is received (e.g. through Ctrl+C), the tuning run + will gracefully shut down and checkpoint the latest experiment state. + Sending SIGINT again (or SIGKILL/SIGTERM instead) will skip this step. + + Many aspects of Tune, such as the frequency of global checkpointing, + maximum pending placement group trials and the path of the result + directory be configured through environment variables. Refer to + :ref:`tune-env-vars` for a list of environment variables available. + + Examples: + + .. code-block:: python + + # Run 10 trials (each trial is one instance of a Trainable). Tune runs + # in parallel and automatically determines concurrency. + tune.run(trainable, num_samples=10) + + # Run 1 trial, stop when trial has reached 10 iterations + tune.run(my_trainable, stop={"training_iteration": 10}) + + # automatically retry failed trials up to 3 times + tune.run(my_trainable, stop={"training_iteration": 10}, max_failures=3) + + # Run 1 trial, search over hyperparameters, stop after 10 iterations. 
+ space = {"lr": tune.uniform(0, 1), "momentum": tune.uniform(0, 1)} + tune.run(my_trainable, config=space, stop={"training_iteration": 10}) + + # Resumes training if a previous machine crashed + tune.run( + my_trainable, config=space, + storage_path=, name=, resume=True + ) + + Args: + run_or_experiment: If function|class|str, this is the algorithm or + model to train. This may refer to the name of a built-on algorithm + (e.g. RLlib's DQN or PPO), a user-defined trainable + function or class, or the string identifier of a + trainable function or class registered in the tune registry. + If Experiment, then Tune will execute training based on + Experiment.spec. If you want to pass in a Python lambda, you + will need to first register the function: + ``tune.register_trainable("lambda_id", lambda x: ...)``. You can + then use ``tune.run("lambda_id")``. + metric: Metric to optimize. This metric should be reported + with `tune.report()`. If set, will be passed to the search + algorithm and scheduler. + mode: Must be one of [min, max]. Determines whether objective is + minimizing or maximizing the metric attribute. If set, will be + passed to the search algorithm and scheduler. + name: Name of experiment. + stop: Stopping criteria. If dict, + the keys may be any field in the return result of 'train()', + whichever is reached first. If function, it must take (trial_id, + result) as arguments and return a boolean (True if trial should be + stopped, False otherwise). This can also be a subclass of + ``ray.tune.Stopper``, which allows users to implement + custom experiment-wide stopping (i.e., stopping an entire Tune + run based on some time constraint). + time_budget_s: Global time budget in + seconds after which all trials are stopped. Can also be a + ``datetime.timedelta`` object. + config: Algorithm-specific configuration for Tune variant + generation (e.g. env, hyperparams). Defaults to empty dict. + Custom search algorithms may ignore this. + resources_per_trial: Machine resources + to allocate per trial, e.g. ``{"cpu": 64, "gpu": 8}``. + Note that GPUs will not be assigned unless you specify them here. + Defaults to 1 CPU and 0 GPUs in + ``Trainable.default_resource_request()``. This can also + be a PlacementGroupFactory object wrapping arguments to create a + per-trial placement group. + num_samples: Number of times to sample from the + hyperparameter space. Defaults to 1. If `grid_search` is + provided as an argument, the grid will be repeated + `num_samples` of times. If this is -1, (virtually) infinite + samples are generated until a stopping condition is met. + storage_path: Path to store results at. Can be a local directory or + a destination on cloud storage. Defaults to + the local ``~/ray_results`` directory. + search_alg: Search algorithm for + optimization. You can also use the name of the algorithm. + scheduler: Scheduler for executing + the experiment. Choose among FIFO (default), MedianStopping, + AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to + ray.tune.schedulers for more options. You can also use the + name of the scheduler. + verbose: 0, 1, or 2. Verbosity mode. + 0 = silent, 1 = default, 2 = verbose. Defaults to 1. + If the ``RAY_AIR_NEW_OUTPUT=1`` environment variable is set, + uses the old verbosity settings: + 0 = silent, 1 = only status updates, 2 = status and brief + results, 3 = status and detailed results. + progress_reporter: Progress reporter for reporting + intermediate experiment progress. 
Defaults to CLIReporter if + running in command-line, or JupyterNotebookReporter if running in + a Jupyter notebook. + log_to_file: Log stdout and stderr to files in + Tune's trial directories. If this is `False` (default), no files + are written. If `true`, outputs are written to `trialdir/stdout` + and `trialdir/stderr`, respectively. If this is a single string, + this is interpreted as a file relative to the trialdir, to which + both streams are written. If this is a Sequence (e.g. a Tuple), + it has to have length 2 and the elements indicate the files to + which stdout and stderr are written, respectively. + trial_name_creator: Optional function that takes in a Trial and returns + its name (i.e. its string representation). Be sure to include some unique + identifier (such as `Trial.trial_id`) in each trial's name. + trial_dirname_creator: Optional function that takes in a trial and + generates its trial directory name as a string. Be sure to include some + unique identifier (such as `Trial.trial_id`) is used in each trial's + directory name. Otherwise, trials could overwrite artifacts and checkpoints + of other trials. The return value cannot be a path. + chdir_to_trial_dir: Deprecated. Set the `RAY_CHDIR_TO_TRIAL_DIR` env var instead + sync_config: Configuration object for syncing. See train.SyncConfig. + export_formats: List of formats that exported at the end of + the experiment. Default is None. + max_failures: Try to recover a trial at least this many times. + Ray will recover from the latest checkpoint if present. + Setting to -1 will lead to infinite recovery retries. + Setting to 0 will disable retries. Defaults to 0. + fail_fast: Whether to fail upon the first error. + If fail_fast='raise' provided, Tune will automatically + raise the exception received by the Trainable. fail_fast='raise' + can easily leak resources and should be used with caution (it + is best used with `ray.init(local_mode=True)`). + restore: Path to checkpoint. Only makes sense to set if + running 1 trial. Defaults to None. + resume: One of [True, False, "AUTO"]. Can + be suffixed with one or more of ["+ERRORED", "+ERRORED_ONLY", + "+RESTART_ERRORED", "+RESTART_ERRORED_ONLY"] (e.g. ``AUTO+ERRORED``). + `resume=True` and `resume="AUTO"` will attempt to resume from a + checkpoint and otherwise start a new experiment. + The suffix "+ERRORED" resets and reruns errored trials upon resume - + previous trial artifacts will be left untouched. It will try to continue + from the last observed checkpoint. + The suffix "+RESTART_ERRORED" will instead start the errored trials from + scratch. "+ERRORED_ONLY" and "+RESTART_ERRORED_ONLY" will disable + resuming non-errored trials - they will be added as finished instead. New + trials can still be generated by the search algorithm. + resume_config: [Experimental] Config object that controls how to resume + trials of different statuses. Can be used as a substitute to the + `resume` suffixes described above. + reuse_actors: Whether to reuse actors between different trials + when possible. This can drastically speed up experiments that start + and stop actors often (e.g., PBT in time-multiplexing mode). This + requires trials to have the same resource requirements. + Defaults to ``False``. + raise_on_failed_trial: Raise TuneError if there exists failed + trial (of ERROR state) when the experiments complete. + callbacks: List of callbacks that will be called at different + times in the training loop. Must be instances of the + ``ray.tune.callback.Callback`` class. 
If not passed, + `LoggerCallback` (json/csv/tensorboard) callbacks are automatically added. + max_concurrent_trials: Maximum number of trials to run + concurrently. Must be non-negative. If None or 0, no limit will + be applied. This is achieved by wrapping the ``search_alg`` in + a :class:`ConcurrencyLimiter`, and thus setting this argument + will raise an exception if the ``search_alg`` is already a + :class:`ConcurrencyLimiter`. Defaults to None. + _remote: Whether to run the Tune driver in a remote function. + This is disabled automatically if a custom trial executor is + passed in. This is enabled by default in Ray client mode. + local_dir: Deprecated. Use `storage_path` instead. + keep_checkpoints_num: Deprecated. use checkpoint_config instead. + checkpoint_score_attr: Deprecated. use checkpoint_config instead. + checkpoint_freq: Deprecated. use checkpoint_config instead. + checkpoint_at_end: Deprecated. use checkpoint_config instead. + checkpoint_keep_all_ranks: Deprecated. use checkpoint_config instead. + checkpoint_upload_from_workers: Deprecated. use checkpoint_config instead. + + Returns: + ExperimentAnalysis: Object for experiment analysis. + + Raises: + TuneError: Any trials failed and `raise_on_failed_trial` is True. + """ + # NO CODE IS TO BE ADDED ABOVE THIS COMMENT + # remote_run_kwargs must be defined before any other + # code is ran to ensure that at this point, + # `locals()` is equal to args and kwargs + remote_run_kwargs = locals().copy() + remote_run_kwargs.pop("_remote") + + if _entrypoint == AirEntrypoint.TRAINER: + error_message_map = { + "entrypoint": "(...)", + "search_space_arg": "param_space", + "restore_entrypoint": '.restore(path="{path}", ...)', + } + elif _entrypoint == AirEntrypoint.TUNER: + error_message_map = { + "entrypoint": "Tuner(...)", + "search_space_arg": "param_space", + "restore_entrypoint": 'Tuner.restore(path="{path}", trainable=...)', + } + elif _entrypoint == AirEntrypoint.TUNE_RUN_EXPERIMENTS: + error_message_map = { + "entrypoint": "tune.run_experiments(...)", + "search_space_arg": "experiment=Experiment(config)", + "restore_entrypoint": "tune.run_experiments(..., resume=True)", + } + else: + error_message_map = { + "entrypoint": "tune.run(...)", + "search_space_arg": "config", + "restore_entrypoint": "tune.run(..., resume=True)", + } + + _ray_auto_init(entrypoint=error_message_map["entrypoint"]) + + if _remote is None: + _remote = ray.util.client.ray.is_connected() + + if verbose is None: + # Default `verbose` value. For new output engine, this is AirVerbosity.VERBOSE. + # For old output engine, this is Verbosity.V3_TRIAL_DETAILS + verbose = get_air_verbosity(AirVerbosity.VERBOSE) or Verbosity.V3_TRIAL_DETAILS + + if _remote: + if get_air_verbosity(verbose) is not None: + logger.info( + "[output] This uses the legacy output and progress reporter, " + "as Ray client is not supported by the new engine. " + "For more information, see " + "https://github.com/ray-project/ray/issues/36949" + ) + + remote_run = ray.remote(num_cpus=0)(run) + + # Make sure tune.run is called on the sever node. 
+ remote_run = _force_on_current_node(remote_run) + + progress_reporter, string_queue = _prepare_progress_reporter_for_ray_client( + progress_reporter, verbose, _remote_string_queue + ) + + # Override with detected progress reporter + remote_run_kwargs["progress_reporter"] = progress_reporter + + remote_future = remote_run.remote(_remote=False, **remote_run_kwargs) + + _stream_client_output( + remote_future, + progress_reporter, + string_queue, + ) + return ray.get(remote_future) + + del remote_run_kwargs + + # TODO(justinvyu): [Deprecated] Remove in 2.30 + ENV_VAR_DEPRECATION_MESSAGE = ( + "The environment variable `{}` is deprecated. " + "It is no longer used and will not have any effect. " + "You should set the `storage_path` instead. Files will no longer be " + "written to `~/ray_results` as long as `storage_path` is set." + "See the docs: https://docs.ray.io/en/latest/train/user-guides/" + "persistent-storage.html#setting-the-local-staging-directory" + ) + if os.environ.get("TUNE_RESULT_DIR"): + raise DeprecationWarning(ENV_VAR_DEPRECATION_MESSAGE.format("TUNE_RESULT_DIR")) + + if os.environ.get("RAY_AIR_LOCAL_CACHE_DIR"): + raise DeprecationWarning( + ENV_VAR_DEPRECATION_MESSAGE.format("RAY_AIR_LOCAL_CACHE_DIR") + ) + + if local_dir is not None: + raise DeprecationWarning( + "The `local_dir` argument is deprecated. " + "You should set the `storage_path` instead. " + "See the docs: https://docs.ray.io/en/latest/train/user-guides/" + "persistent-storage.html#setting-the-local-staging-directory" + ) + + ray._private.usage.usage_lib.record_library_usage("tune") + + # Tracking environment variable usage here will also catch: + # 1.) Tuner.fit() usage + # 2.) Trainer.fit() usage + # 3.) Ray client usage (env variables are inherited by the Ray runtime env) + air_usage.tag_ray_air_env_vars() + + # Track the entrypoint to AIR: + # Tuner.fit / Trainer.fit / tune.run / tune.run_experiments + air_usage.tag_air_entrypoint(_entrypoint) + + all_start = time.time() + + if mode and mode not in ["min", "max"]: + raise ValueError( + f"The `mode` parameter passed to `{error_message_map['entrypoint']}` " + "must be one of ['min', 'max']" + ) + + air_verbosity = get_air_verbosity(verbose) + if air_verbosity is not None and IS_NOTEBOOK: + logger.info( + "[output] This uses the legacy output and progress reporter, " + "as Jupyter notebooks are not supported by the new engine, yet. " + "For more information, please see " + "https://github.com/ray-project/ray/issues/36949" + ) + air_verbosity = None + + if air_verbosity is not None: + # Disable old output engine + set_verbosity(0) + else: + # Use old output engine + set_verbosity(verbose) + + config = config or {} + if isinstance(config, _Config): + config = config.to_dict() + if not isinstance(config, dict): + raise ValueError( + f"The `{error_message_map['search_space_arg']}` passed to " + f"`{error_message_map['entrypoint']}` must be a dict. " + f"Got '{type(config)}' instead." + ) + + sync_config = sync_config or SyncConfig() + checkpoint_config = checkpoint_config or CheckpointConfig() + + # For backward compatibility + # TODO(jungong): remove after 2.7 release. + if keep_checkpoints_num is not None: + warnings.warn( + "keep_checkpoints_num is deprecated and will be removed. " + "use checkpoint_config.num_to_keep instead.", + DeprecationWarning, + ) + checkpoint_config.num_to_keep = keep_checkpoints_num + if checkpoint_score_attr is not None: + warnings.warn( + "checkpoint_score_attr is deprecated and will be removed. 
" + "use checkpoint_config.checkpoint_score_attribute instead.", + DeprecationWarning, + ) + + if checkpoint_score_attr.startswith("min-"): + warnings.warn( + "using min- and max- prefixes to specify checkpoint score " + "order is deprecated. Use CheckpointConfig.checkpoint_score_order " + "instead", + DeprecationWarning, + ) + checkpoint_config.checkpoint_score_attribute = checkpoint_score_attr[4:] + checkpoint_config.checkpoint_score_order = "min" + else: + checkpoint_config.checkpoint_score_attribute = checkpoint_score_attr + checkpoint_config.checkpoint_score_order = "max" + + checkpoint_config.score_attr = checkpoint_score_attr + if checkpoint_freq > 0: + warnings.warn( + "checkpoint_freq is deprecated and will be removed. " + "use checkpoint_config.checkpoint_frequency instead.", + DeprecationWarning, + ) + checkpoint_config.checkpoint_frequency = checkpoint_freq + if checkpoint_at_end: + warnings.warn( + "checkpoint_at_end is deprecated and will be removed. " + "use checkpoint_config.checkpoint_at_end instead.", + DeprecationWarning, + ) + checkpoint_config.checkpoint_at_end = checkpoint_at_end + + # TODO(justinvyu): [Deprecated] Remove in 2.11. + if chdir_to_trial_dir != _DEPRECATED_VALUE: + raise DeprecationWarning( + "`chdir_to_trial_dir` is deprecated. " + f"Use the {RAY_CHDIR_TO_TRIAL_DIR} environment variable instead. " + "Set it to 0 to disable the default behavior of changing the " + "working directory.", + DeprecationWarning, + ) + + if num_samples == -1: + num_samples = sys.maxsize + + # Create scheduler here as we need access to some of its properties + if isinstance(scheduler, str): + # importing at top level causes a recursive dependency + from ray.tune.schedulers import create_scheduler + + scheduler = create_scheduler(scheduler) + scheduler = scheduler or FIFOScheduler() + + if not scheduler.supports_buffered_results: + # Result buffering with e.g. a Hyperband scheduler is a bad idea, as + # hyperband tries to stop trials when processing brackets. With result + # buffering, we might trigger this multiple times when evaluating + # a single trial, which leads to unexpected behavior. + env_result_buffer_length = os.getenv("TUNE_RESULT_BUFFER_LENGTH", "") + if env_result_buffer_length: + warnings.warn( + f"You are using a {type(scheduler)} scheduler, but " + f"TUNE_RESULT_BUFFER_LENGTH is set " + f"({env_result_buffer_length}). This can lead to undesired " + f"and faulty behavior, so the buffer length was forcibly set " + f"to 1 instead." + ) + os.environ["TUNE_RESULT_BUFFER_LENGTH"] = "1" + + if ( + isinstance(scheduler, (PopulationBasedTraining, PopulationBasedTrainingReplay)) + and not reuse_actors + ): + warnings.warn( + "Consider boosting PBT performance by enabling `reuse_actors` as " + "well as implementing `reset_config` for Trainable." + ) + + # Before experiments are created, we first clean up the passed in + # Config dictionary by replacing all the non-primitive config values + # with placeholders. This serves two purposes: + # 1. we can replace and "fix" these objects if a Trial is restored. + # 2. the config dictionary will then be compatible with all supported + # search algorithms, since a lot of them do not support non-primitive + # config values. + placeholder_resolvers = create_resolvers_map() + config = inject_placeholders( + # Make a deep copy here to avoid modifying the original config dict. + copy.deepcopy(config), + placeholder_resolvers, + ) + + # TODO(justinvyu): We should remove the ability to pass a list of + # trainables to tune.run. 
+ if isinstance(run_or_experiment, list): + experiments = run_or_experiment + else: + experiments = [run_or_experiment] + + for i, exp in enumerate(experiments): + if not isinstance(exp, Experiment): + experiments[i] = Experiment( + name=name, + run=exp, + stop=stop, + time_budget_s=time_budget_s, + config=config, + resources_per_trial=resources_per_trial, + num_samples=num_samples, + storage_path=storage_path, + storage_filesystem=storage_filesystem, + sync_config=sync_config, + checkpoint_config=checkpoint_config, + trial_name_creator=trial_name_creator, + trial_dirname_creator=trial_dirname_creator, + log_to_file=log_to_file, + export_formats=export_formats, + max_failures=max_failures, + restore=restore, + ) + + if fail_fast and max_failures != 0: + raise ValueError("max_failures must be 0 if fail_fast=True.") + + if isinstance(search_alg, str): + search_alg = create_searcher(search_alg) + + # if local_mode=True is set during ray.init(). + is_local_mode = ray._private.worker._mode() == ray._private.worker.LOCAL_MODE + + if is_local_mode: + max_concurrent_trials = 1 + + if not search_alg: + search_alg = BasicVariantGenerator(max_concurrent=max_concurrent_trials or 0) + elif max_concurrent_trials or is_local_mode: + if isinstance(search_alg, ConcurrencyLimiter): + if not is_local_mode: + if search_alg.max_concurrent != max_concurrent_trials: + raise ValueError( + "You have specified `max_concurrent_trials=" + f"{max_concurrent_trials}`, but the `search_alg` is " + "already a `ConcurrencyLimiter` with `max_concurrent=" + f"{search_alg.max_concurrent}. FIX THIS by setting " + "`max_concurrent_trials=None`." + ) + else: + logger.warning( + "You have specified `max_concurrent_trials=" + f"{max_concurrent_trials}`, but the `search_alg` is " + "already a `ConcurrencyLimiter`. " + "`max_concurrent_trials` will be ignored." + ) + else: + if max_concurrent_trials < 1: + raise ValueError( + "`max_concurrent_trials` must be greater or equal than 1, " + f"got {max_concurrent_trials}." + ) + if isinstance(search_alg, Searcher): + search_alg = ConcurrencyLimiter( + search_alg, max_concurrent=max_concurrent_trials + ) + elif not is_local_mode: + logger.warning( + "You have passed a `SearchGenerator` instance as the " + "`search_alg`, but `max_concurrent_trials` requires a " + "`Searcher` instance`. `max_concurrent_trials` " + "will be ignored." + ) + + if isinstance(search_alg, Searcher): + search_alg = SearchGenerator(search_alg) + + if config and not searcher_set_search_props( + search_alg.set_search_properties, + metric, + mode, + config, + **experiments[0].public_spec, + ): + if _has_unresolved_values(config): + raise ValueError( + f"You passed a `{error_message_map['search_space_arg']}` parameter to " + f"`{error_message_map['entrypoint']}` with " + "unresolved parameters, but the search algorithm was already " + "instantiated with a search space. Make sure that `config` " + "does not contain any more parameter definitions - include " + "them in the search algorithm's search space if necessary." + ) + + if not scheduler_set_search_props( + scheduler.set_search_properties, metric, mode, **experiments[0].public_spec + ): + raise ValueError( + "You passed a `metric` or `mode` argument to " + f"`{error_message_map['entrypoint']}`, but " + "the scheduler you are using was already instantiated with their " + "own `metric` and `mode` parameters. Either remove the arguments " + f"from your scheduler or from `{error_message_map['entrypoint']}` args." 
+ ) + + progress_metrics = _detect_progress_metrics(_get_trainable(run_or_experiment)) + + air_usage.tag_storage_type(experiments[0].storage) + + # NOTE: Report callback telemetry before populating the list with default callbacks. + # This tracks user-specified callback usage. + air_usage.tag_callbacks(callbacks) + + # Create default logging + syncer callbacks + callbacks = _create_default_callbacks( + callbacks, + air_verbosity=air_verbosity, + entrypoint=_entrypoint, + config=config, + metric=metric, + mode=mode, + progress_metrics=progress_metrics, + ) + + # User Warning for GPUs + if ray.cluster_resources().get("GPU", 0): + if _check_gpus_in_resources(resources=resources_per_trial): + # "gpu" is manually set. + pass + elif _check_default_resources_override(experiments[0].run_identifier): + # "default_resources" is manually overridden. + pass + else: + logger.warning( + "Tune detects GPUs, but no trials are using GPUs. " + "To enable trials to use GPUs, wrap `train_func` with " + "`tune.with_resources(train_func, resources_per_trial={'gpu': 1})` " + "which allows Tune to expose 1 GPU to each trial. " + "For Ray Train Trainers, you can specify GPU resources " + "through `ScalingConfig(use_gpu=True)`. " + "You can also override " + "`Trainable.default_resource_request` if using the " + "Trainable API." + ) + + experiment_interrupted_event = _setup_signal_catching() + + if progress_reporter and air_verbosity is not None: + logger.warning( + "AIR_VERBOSITY is set, ignoring passed-in ProgressReporter for now." + ) + progress_reporter = None + + if air_verbosity is None: + is_trainer = _entrypoint == AirEntrypoint.TRAINER + progress_reporter = progress_reporter or _detect_reporter( + _trainer_api=is_trainer + ) + + if resume is not None: + resume_config = resume_config or _build_resume_config_from_legacy_config(resume) + + runner_kwargs = dict( + search_alg=search_alg, + placeholder_resolvers=placeholder_resolvers, + scheduler=scheduler, + stopper=experiments[0].stopper, + resume_config=resume_config, + fail_fast=fail_fast, + callbacks=callbacks, + metric=metric, + trial_checkpoint_config=experiments[0].checkpoint_config, + reuse_actors=reuse_actors, + storage=experiments[0].storage, + _trainer_api=_entrypoint == AirEntrypoint.TRAINER, + ) + + runner = TuneController(**runner_kwargs) + + if not runner.resumed: + for exp in experiments: + search_alg.add_configurations([exp]) + # search_alg.total_samples has been updated, so we should + # update the number of pending trials + runner.update_max_pending_trials() + else: + logger.debug( + "You have resumed the Tune run, which means that any newly specified " + "`Experiment`s will be ignored. " + "Tune will just continue what was previously running." 
+ ) + if resources_per_trial: + runner.update_pending_trial_resources(resources_per_trial) + + # Calls setup on callbacks + runner.setup_experiments( + experiments=experiments, total_num_samples=search_alg.total_samples + ) + + tune_start = time.time() + + air_progress_reporter = None + if air_verbosity is None: + progress_reporter.setup( + start_time=tune_start, + total_samples=search_alg.total_samples, + metric=metric, + mode=mode, + ) + else: + from ray.tune.experimental.output import ProgressReporter as AirProgressReporter + + for callback in callbacks: + if isinstance(callback, AirProgressReporter): + air_progress_reporter = callback + air_progress_reporter.setup( + start_time=tune_start, total_samples=search_alg.total_samples + ) + break + + experiment_local_path = runner._storage.experiment_driver_staging_path + experiment_dir_name = runner._storage.experiment_dir_name + + if any(isinstance(cb, TBXLoggerCallback) for cb in callbacks): + tensorboard_path = experiment_local_path + else: + tensorboard_path = None + + if air_progress_reporter: + air_progress_reporter.experiment_started( + experiment_name=experiment_dir_name, + experiment_path=runner.experiment_path, + searcher_str=search_alg.__class__.__name__, + scheduler_str=scheduler.__class__.__name__, + total_num_samples=search_alg.total_samples, + tensorboard_path=tensorboard_path, + ) + + try: + while not runner.is_finished() and not experiment_interrupted_event.is_set(): + runner.step() + if has_verbosity(Verbosity.V1_EXPERIMENT): + _report_progress(runner, progress_reporter) + + if air_verbosity is not None: + _report_air_progress(runner, air_progress_reporter) + except Exception: + runner.cleanup() + raise + + tune_taken = time.time() - tune_start + + final_sync_start = time.time() + try: + runner.checkpoint(force=True, wait=True) + logger.info( + "Wrote the latest version of all result files and experiment state to " + f"'{runner.experiment_path}' in {time.time() - final_sync_start:.4f}s." + ) + except Exception: + logger.error( + "Experiment state snapshotting failed:", exc_info=True, stack_info=True + ) + + if has_verbosity(Verbosity.V1_EXPERIMENT): + _report_progress(runner, progress_reporter, done=True) + + if air_verbosity is not None: + _report_air_progress(runner, air_progress_reporter, force=True) + + all_trials = runner.get_trials() + + runner.cleanup() + + incomplete_trials = [] + for trial in all_trials: + if trial.status != Trial.TERMINATED: + incomplete_trials += [trial] + + if incomplete_trials: + if raise_on_failed_trial and not experiment_interrupted_event.is_set(): + raise TuneError("Trials did not complete", incomplete_trials) + else: + logger.error("Trials did not complete: %s", incomplete_trials) + + all_taken = time.time() - all_start + if has_verbosity(Verbosity.V1_EXPERIMENT): + logger.info( + f"Total run time: {all_taken:.2f} seconds " + f"({tune_taken:.2f} seconds for the tuning loop)." 
+ ) + + if experiment_interrupted_event.is_set(): + restore_entrypoint = error_message_map["restore_entrypoint"].format( + path=runner.experiment_path, + ) + if _entrypoint == AirEntrypoint.TRAINER: + logger.warning( + f"Training has been interrupted, but the most recent state was saved.\n" + f"Resume training with: {restore_entrypoint}" + ) + else: + logger.warning( + f"Experiment has been interrupted, but the most recent state was " + f"saved.\nResume experiment with: {restore_entrypoint}" + ) + + return ExperimentAnalysis( + experiment_checkpoint_path=runner.experiment_path, + default_metric=metric, + default_mode=mode, + trials=all_trials, + storage_filesystem=experiments[0].storage.storage_filesystem, + ) + + +@PublicAPI +def run_experiments( + experiments: Union[Experiment, Mapping, Sequence[Union[Experiment, Mapping]]], + scheduler: Optional[TrialScheduler] = None, + verbose: Optional[Union[int, AirVerbosity, Verbosity]] = None, + progress_reporter: Optional[ProgressReporter] = None, + resume: Optional[Union[bool, str]] = None, + resume_config: Optional[ResumeConfig] = None, + reuse_actors: bool = False, + raise_on_failed_trial: bool = True, + concurrent: bool = True, + callbacks: Optional[Sequence[Callback]] = None, + _remote: Optional[bool] = None, +): + """Runs and blocks until all trials finish. + + Example: + >>> from ray.tune.experiment import Experiment + >>> from ray.tune.tune import run_experiments + >>> def my_func(config): return {"score": 0} + >>> experiment_spec = Experiment("experiment", my_func) # doctest: +SKIP + >>> run_experiments(experiments=experiment_spec) # doctest: +SKIP + >>> experiment_spec = {"experiment": {"run": my_func}} # doctest: +SKIP + >>> run_experiments(experiments=experiment_spec) # doctest: +SKIP + + Returns: + List of Trial objects, holding data for each executed trial. + + """ + if _remote is None: + _remote = ray.util.client.ray.is_connected() + + _ray_auto_init(entrypoint="tune.run_experiments(...)") + + if verbose is None: + # Default `verbose` value. For new output engine, this is AirVerbosity.VERBOSE. + # For old output engine, this is Verbosity.V3_TRIAL_DETAILS + verbose = get_air_verbosity(AirVerbosity.VERBOSE) or Verbosity.V3_TRIAL_DETAILS + + if _remote: + if get_air_verbosity(verbose) is not None: + logger.info( + "[output] This uses the legacy output and progress reporter, " + "as Ray client is not supported by the new engine. " + "For more information, see " + "https://github.com/ray-project/ray/issues/36949" + ) + remote_run = ray.remote(num_cpus=0)(run_experiments) + + # Make sure tune.run_experiments is run on the server node. + remote_run = _force_on_current_node(remote_run) + + return ray.get( + remote_run.remote( + experiments, + scheduler, + verbose, + progress_reporter, + resume, + resume_config, + reuse_actors, + raise_on_failed_trial, + concurrent, + callbacks, + _remote=False, + ) + ) + + # This is important to do this here + # because it schematize the experiments + # and it conducts the implicit registration. 
+ experiments = _convert_to_experiment_list(experiments) + + tune_run_params = dict( + verbose=verbose, + progress_reporter=progress_reporter, + resume=resume, + resume_config=resume_config, + reuse_actors=reuse_actors, + raise_on_failed_trial=raise_on_failed_trial, + scheduler=scheduler, + callbacks=callbacks, + _entrypoint=AirEntrypoint.TUNE_RUN_EXPERIMENTS, + ) + + if concurrent: + return run(experiments, **tune_run_params).trials + else: + trials = [] + for exp in experiments: + trials += run(exp, **tune_run_params).trials + return trials diff --git a/.venv/lib/python3.11/site-packages/ray/tune/tune_config.py b/.venv/lib/python3.11/site-packages/ray/tune/tune_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ebaba70cdef8ee9495406bed1d6db30b74a1ac26 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/tune_config.py @@ -0,0 +1,99 @@ +import datetime +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Optional, Union + +from ray.train.constants import _DEPRECATED_VALUE +from ray.tune.experiment.trial import Trial +from ray.tune.schedulers import TrialScheduler +from ray.tune.search import SearchAlgorithm, Searcher +from ray.util.annotations import DeveloperAPI, PublicAPI + + +@dataclass +@PublicAPI(stability="beta") +class TuneConfig: + """Tune specific configs. + + Args: + metric: Metric to optimize. This metric should be reported + with `tune.report()`. If set, will be passed to the search + algorithm and scheduler. + mode: Must be one of [min, max]. Determines whether objective is + minimizing or maximizing the metric attribute. If set, will be + passed to the search algorithm and scheduler. + search_alg: Search algorithm for optimization. Default to + random search. + scheduler: Scheduler for executing the experiment. + Choose among FIFO (default), MedianStopping, + AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to + ray.tune.schedulers for more options. + num_samples: Number of times to sample from the + hyperparameter space. Defaults to 1. If `grid_search` is + provided as an argument, the grid will be repeated + `num_samples` of times. If this is -1, (virtually) infinite + samples are generated until a stopping condition is met. + max_concurrent_trials: Maximum number of trials to run + concurrently. Must be non-negative. If None or 0, no limit will + be applied. This is achieved by wrapping the ``search_alg`` in + a :class:`ConcurrencyLimiter`, and thus setting this argument + will raise an exception if the ``search_alg`` is already a + :class:`ConcurrencyLimiter`. Defaults to None. + time_budget_s: Global time budget in + seconds after which all trials are stopped. Can also be a + ``datetime.timedelta`` object. + reuse_actors: Whether to reuse actors between different trials + when possible. This can drastically speed up experiments that start + and stop actors often (e.g., PBT in time-multiplexing mode). This + requires trials to have the same resource requirements. + Defaults to ``False``. + trial_name_creator: Optional function that takes in a Trial and returns + its name (i.e. its string representation). Be sure to include some unique + identifier (such as `Trial.trial_id`) in each trial's name. + NOTE: This API is in alpha and subject to change. + trial_dirname_creator: Optional function that takes in a trial and + generates its trial directory name as a string. Be sure to include some + unique identifier (such as `Trial.trial_id`) is used in each trial's + directory name. 
Otherwise, trials could overwrite artifacts and checkpoints + of other trials. The return value cannot be a path. + NOTE: This API is in alpha and subject to change. + chdir_to_trial_dir: Deprecated. Set the `RAY_CHDIR_TO_TRIAL_DIR` env var instead + """ + + # Currently this is not at feature parity with `tune.run`, nor should it be. + # The goal is to reach a fine balance between API flexibility and conciseness. + # We should carefully introduce arguments here instead of just dumping everything. + mode: Optional[str] = None + metric: Optional[str] = None + search_alg: Optional[Union[Searcher, SearchAlgorithm]] = None + scheduler: Optional[TrialScheduler] = None + num_samples: int = 1 + max_concurrent_trials: Optional[int] = None + time_budget_s: Optional[Union[int, float, datetime.timedelta]] = None + reuse_actors: bool = False + trial_name_creator: Optional[Callable[[Trial], str]] = None + trial_dirname_creator: Optional[Callable[[Trial], str]] = None + chdir_to_trial_dir: bool = _DEPRECATED_VALUE + + +@DeveloperAPI +@dataclass +class ResumeConfig: + """[Experimental] This config is used to specify how to resume Tune trials.""" + + class ResumeType(Enum): + """An enumeration to define resume types for various trial states. + + Members: + RESUME: Resume from the latest checkpoint. + RESTART: Restart from the beginning (with no checkpoint). + SKIP: Skip this trial when resuming by treating it as terminated. + """ + + RESUME = "resume" + RESTART = "restart" + SKIP = "skip" + + finished: str = ResumeType.SKIP + unfinished: str = ResumeType.RESUME + errored: str = ResumeType.SKIP diff --git a/.venv/lib/python3.11/site-packages/ray/tune/tuner.py b/.venv/lib/python3.11/site-packages/ray/tune/tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..86d7cae5537583003310ddd9e79cca9a483a5bf1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/tuner.py @@ -0,0 +1,434 @@ +import logging +import os +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union + +import pyarrow.fs + +import ray +from ray.air._internal.usage import AirEntrypoint +from ray.air.config import RunConfig +from ray.air.util.node import _force_on_current_node +from ray.train._internal.storage import _exists_at_fs_path, get_fs_and_path +from ray.tune import ResumeConfig +from ray.tune.experimental.output import get_air_verbosity +from ray.tune.impl.tuner_internal import _TUNER_PKL, TunerInternal +from ray.tune.progress_reporter import ( + _prepare_progress_reporter_for_ray_client, + _stream_client_output, +) +from ray.tune.result_grid import ResultGrid +from ray.tune.trainable import Trainable +from ray.tune.tune_config import TuneConfig +from ray.util import PublicAPI + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from ray.train.base_trainer import BaseTrainer + +ClientActorHandle = Any + +# try: +# # Breaks lint right now. +# from ray.util.client.common import ClientActorHandle +# except Exception: +# pass + +# The magic key that is used when instantiating Tuner during resume. +_TUNER_INTERNAL = "_tuner_internal" +_SELF = "self" + + +@PublicAPI(stability="beta") +class Tuner: + """Tuner is the recommended way of launching hyperparameter tuning jobs with Ray Tune. + + Args: + trainable: The trainable to be tuned. + param_space: Search space of the tuning job. + One thing to note is that both preprocessor and dataset can be tuned here. + tune_config: Tuning algorithm specific configs. 
+ Refer to ray.tune.tune_config.TuneConfig for more info. + run_config: Runtime configuration that is specific to individual trials. + If passed, this will overwrite the run config passed to the Trainer, + if applicable. Refer to ray.train.RunConfig for more info. + + Usage pattern: + + .. code-block:: python + + from sklearn.datasets import load_breast_cancer + + from ray import tune + from ray.data import from_pandas + from ray.train import RunConfig, ScalingConfig + from ray.train.xgboost import XGBoostTrainer + from ray.tune.tuner import Tuner + + def get_dataset(): + data_raw = load_breast_cancer(as_frame=True) + dataset_df = data_raw["data"] + dataset_df["target"] = data_raw["target"] + dataset = from_pandas(dataset_df) + return dataset + + trainer = XGBoostTrainer( + label_column="target", + params={}, + datasets={"train": get_dataset()}, + ) + + param_space = { + "scaling_config": ScalingConfig( + num_workers=tune.grid_search([2, 4]), + resources_per_worker={ + "CPU": tune.grid_search([1, 2]), + }, + ), + # You can even grid search various datasets in Tune. + # "datasets": { + # "train": tune.grid_search( + # [ds1, ds2] + # ), + # }, + "params": { + "objective": "binary:logistic", + "tree_method": "approx", + "eval_metric": ["logloss", "error"], + "eta": tune.loguniform(1e-4, 1e-1), + "subsample": tune.uniform(0.5, 1.0), + "max_depth": tune.randint(1, 9), + }, + } + tuner = Tuner(trainable=trainer, param_space=param_space, + run_config=RunConfig(name="my_tune_run")) + results = tuner.fit() + + To retry a failed tune run, you can then do + + .. code-block:: python + + tuner = Tuner.restore(results.experiment_path, trainable=trainer) + tuner.fit() + + ``results.experiment_path`` can be retrieved from the + :ref:`ResultGrid object `. It can + also be easily seen in the log output from your first run. + + """ + + # One of the following is assigned. + _local_tuner: Optional[TunerInternal] # Only used in none ray client mode. + _remote_tuner: Optional[ClientActorHandle] # Only used in ray client mode. + + def __init__( + self, + trainable: Optional[ + Union[str, Callable, Type[Trainable], "BaseTrainer"] + ] = None, + *, + param_space: Optional[Dict[str, Any]] = None, + tune_config: Optional[TuneConfig] = None, + run_config: Optional[RunConfig] = None, + # This is internal only arg. + # Only for dogfooding purposes. We can slowly promote these args + # to RunConfig or TuneConfig as needed. + # TODO(xwjiang): Remove this later. + _tuner_kwargs: Optional[Dict] = None, + _tuner_internal: Optional[TunerInternal] = None, + _entrypoint: AirEntrypoint = AirEntrypoint.TUNER, + ): + """Configure and construct a tune run.""" + kwargs = locals().copy() + self._is_ray_client = ray.util.client.ray.is_connected() + if self._is_ray_client: + _run_config = run_config or RunConfig() + if get_air_verbosity(_run_config.verbose) is not None: + logger.info( + "[output] This uses the legacy output and progress reporter, " + "as Ray client is not supported by the new engine. 
" + "For more information, see " + "https://github.com/ray-project/ray/issues/36949" + ) + + if _tuner_internal: + if not self._is_ray_client: + self._local_tuner = kwargs[_TUNER_INTERNAL] + else: + self._remote_tuner = kwargs[_TUNER_INTERNAL] + else: + kwargs.pop(_TUNER_INTERNAL, None) + kwargs.pop(_SELF, None) + if not self._is_ray_client: + self._local_tuner = TunerInternal(**kwargs) + else: + self._remote_tuner = _force_on_current_node( + ray.remote(num_cpus=0)(TunerInternal) + ).remote(**kwargs) + + @classmethod + def restore( + cls, + path: str, + trainable: Union[str, Callable, Type[Trainable], "BaseTrainer"], + resume_unfinished: bool = True, + resume_errored: bool = False, + restart_errored: bool = False, + param_space: Optional[Dict[str, Any]] = None, + storage_filesystem: Optional[pyarrow.fs.FileSystem] = None, + _resume_config: Optional[ResumeConfig] = None, + ) -> "Tuner": + """Restores Tuner after a previously failed run. + + All trials from the existing run will be added to the result table. The + argument flags control how existing but unfinished or errored trials are + resumed. + + Finished trials are always added to the overview table. They will not be + resumed. + + Unfinished trials can be controlled with the ``resume_unfinished`` flag. + If ``True`` (default), they will be continued. If ``False``, they will + be added as terminated trials (even if they were only created and never + trained). + + Errored trials can be controlled with the ``resume_errored`` and + ``restart_errored`` flags. The former will resume errored trials from + their latest checkpoints. The latter will restart errored trials from + scratch and prevent loading their last checkpoints. + + .. note:: + + Restoring an experiment from a path that's pointing to a *different* + location than the original experiment path is supported. + However, Ray Tune assumes that the full experiment directory is available + (including checkpoints) so that it's possible to resume trials from their + latest state. + + For example, if the original experiment path was run locally, + then the results are uploaded to cloud storage, Ray Tune expects the full + contents to be available in cloud storage if attempting to resume + via ``Tuner.restore("s3://...")``. The restored run will continue + writing results to the same cloud storage location. + + Args: + path: The local or remote path of the experiment directory + for an interrupted or failed run. + Note that an experiment where all trials finished will not be resumed. + This information could be easily located near the end of the + console output of previous run. + trainable: The trainable to use upon resuming the experiment. + This should be the same trainable that was used to initialize + the original Tuner. + param_space: The same `param_space` that was passed to + the original Tuner. This can be optionally re-specified due + to the `param_space` potentially containing Ray object + references (tuning over Datasets or tuning over + several `ray.put` object references). **Tune expects the + `param_space` to be unmodified**, and the only part that + will be used during restore are the updated object references. + Changing the hyperparameter search space then resuming is NOT + supported by this API. + resume_unfinished: If True, will continue to run unfinished trials. + resume_errored: If True, will re-schedule errored trials and try to + restore from their latest checkpoints. 
+ restart_errored: If True, will re-schedule errored trials but force + restarting them from scratch (no checkpoint will be loaded). + storage_filesystem: Custom ``pyarrow.fs.FileSystem`` + corresponding to the ``path``. This may be necessary if the original + experiment passed in a custom filesystem. + _resume_config: [Experimental] Config object that controls how to resume + trials of different statuses. Can be used as a substitute to + `resume_*` and `restart_*` flags above. + """ + unfinished = ( + ResumeConfig.ResumeType.RESUME + if resume_unfinished + else ResumeConfig.ResumeType.SKIP + ) + errored = ResumeConfig.ResumeType.SKIP + if resume_errored: + errored = ResumeConfig.ResumeType.RESUME + elif restart_errored: + errored = ResumeConfig.ResumeType.RESTART + + resume_config = _resume_config or ResumeConfig( + unfinished=unfinished, errored=errored + ) + + if not ray.util.client.ray.is_connected(): + tuner_internal = TunerInternal( + restore_path=path, + resume_config=resume_config, + trainable=trainable, + param_space=param_space, + storage_filesystem=storage_filesystem, + ) + return Tuner(_tuner_internal=tuner_internal) + else: + tuner_internal = _force_on_current_node( + ray.remote(num_cpus=0)(TunerInternal) + ).remote( + restore_path=path, + resume_config=resume_config, + trainable=trainable, + param_space=param_space, + storage_filesystem=storage_filesystem, + ) + return Tuner(_tuner_internal=tuner_internal) + + @classmethod + def can_restore( + cls, + path: Union[str, os.PathLike], + storage_filesystem: Optional[pyarrow.fs.FileSystem] = None, + ) -> bool: + """Checks whether a given directory contains a restorable Tune experiment. + + Usage Pattern: + + Use this utility to switch between starting a new Tune experiment + and restoring when possible. This is useful for experiment fault-tolerance + when re-running a failed tuning script. + + .. code-block:: python + + import os + from ray.tune import Tuner + from ray.train import RunConfig + + def train_fn(config): + # Make sure to implement checkpointing so that progress gets + # saved on restore. + pass + + name = "exp_name" + storage_path = os.path.expanduser("~/ray_results") + exp_dir = os.path.join(storage_path, name) + + if Tuner.can_restore(exp_dir): + tuner = Tuner.restore(exp_dir, trainable=train_fn, resume_errored=True) + else: + tuner = Tuner( + train_fn, + run_config=RunConfig(name=name, storage_path=storage_path), + ) + tuner.fit() + + Args: + path: The path to the experiment directory of the Tune experiment. + This can be either a local directory or a remote URI + (e.g. s3://bucket/exp_name). + + Returns: + bool: True if this path exists and contains the Tuner state to resume from + """ + fs, fs_path = get_fs_and_path(path, storage_filesystem) + return _exists_at_fs_path(fs, Path(fs_path, _TUNER_PKL).as_posix()) + + def _prepare_remote_tuner_for_jupyter_progress_reporting(self): + run_config: RunConfig = ray.get(self._remote_tuner.get_run_config.remote()) + progress_reporter, string_queue = _prepare_progress_reporter_for_ray_client( + run_config.progress_reporter, run_config.verbose + ) + run_config.progress_reporter = progress_reporter + ray.get( + self._remote_tuner.set_run_config_and_remote_string_queue.remote( + run_config, string_queue + ) + ) + + return progress_reporter, string_queue + + def fit(self) -> ResultGrid: + """Executes hyperparameter tuning job as configured and returns result. 
+ + Failure handling: + For the kind of exception that happens during the execution of a trial, + one may inspect it together with stacktrace through the returned result grid. + See ``ResultGrid`` for reference. Each trial may fail up to a certain number. + This is configured by ``RunConfig.FailureConfig.max_failures``. + + Exception that happens beyond trials will be thrown by this method as well. + In such cases, there will be instruction like the following printed out + at the end of console output to inform users on how to resume. + + Please use `Tuner.restore` to resume. + + .. code-block:: python + + import os + from ray.tune import Tuner + + trainable = ... + + tuner = Tuner.restore( + os.path.expanduser("~/ray_results/tuner_resume"), + trainable=trainable + ) + tuner.fit() + + Raises: + RayTaskError: If user-provided trainable raises an exception + """ + + if not self._is_ray_client: + return self._local_tuner.fit() + else: + ( + progress_reporter, + string_queue, + ) = self._prepare_remote_tuner_for_jupyter_progress_reporting() + fit_future = self._remote_tuner.fit.remote() + _stream_client_output( + fit_future, + progress_reporter, + string_queue, + ) + return ray.get(fit_future) + + def get_results(self) -> ResultGrid: + """Get results of a hyperparameter tuning run. + + This method returns the same results as :meth:`~ray.tune.Tuner.fit` + and can be used to retrieve the results after restoring a tuner without + calling ``fit()`` again. + + If the tuner has not been fit before, an error will be raised. + + .. code-block:: python + + from ray.tune import Tuner + + # `trainable` is what was passed in to the original `Tuner` + tuner = Tuner.restore("/path/to/experiment', trainable=trainable) + results = tuner.get_results() + + Returns: + Result grid of a previously fitted tuning run. + + """ + if not self._is_ray_client: + return self._local_tuner.get_results() + else: + ( + progress_reporter, + string_queue, + ) = self._prepare_remote_tuner_for_jupyter_progress_reporting() + get_results_future = self._remote_tuner.get_results.remote() + _stream_client_output( + get_results_future, + progress_reporter, + string_queue, + ) + return ray.get(get_results_future) + + def __getattribute__(self, item): + if item == "restore": + raise AttributeError( + "`Tuner.restore()` is a classmethod and cannot be called on an " + "instance. Use `tuner = Tuner.restore(...)` to instantiate the " + "Tuner instead." + ) + return super().__getattribute__(item)
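For reference, a minimal usage sketch of the `Tuner` API defined in the vendored `tuner.py` above, using the same reporting convention as the `run_experiments()` doctest earlier in this diff. The trainable `train_fn` and its `"lr"` search space are illustrative assumptions and are not part of the vendored code.

.. code-block:: python

    # Illustrative sketch only: `train_fn` and the "lr" search space are
    # assumptions, not part of the vendored module above.
    from ray import tune
    from ray.tune import Tuner, TuneConfig

    def train_fn(config):
        # A function trainable may simply return a dict of metrics,
        # mirroring the run_experiments() doctest shown above.
        return {"score": config["lr"] ** 2}

    tuner = Tuner(
        train_fn,
        param_space={"lr": tune.uniform(0.0, 1.0)},
        tune_config=TuneConfig(metric="score", mode="min", num_samples=4),
    )
    results = tuner.fit()
    # Best result is selected using the metric/mode given in TuneConfig.
    print(results.get_best_result().metrics)

If the run is interrupted, the experiment can later be picked up with `Tuner.restore(experiment_path, trainable=train_fn)` or checked first with `Tuner.can_restore(experiment_path)`, as documented in the class above.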