Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/dynamic_tf_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/dynamic_tf_policy_v2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/eager_tf_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/eager_tf_policy_v2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy_map.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy_template.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/rnn_sequencing.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/sample_batch.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_mixins.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_policy_template.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/view_requirement.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__init__.py +115 -0
- .venv/lib/python3.11/site-packages/ray/tune/automl/__init__.py +1 -0
- .venv/lib/python3.11/site-packages/ray/tune/automl/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/callback.py +512 -0
- .venv/lib/python3.11/site-packages/ray/tune/constants.py +32 -0
- .venv/lib/python3.11/site-packages/ray/tune/context.py +113 -0
- .venv/lib/python3.11/site-packages/ray/tune/error.py +48 -0
- .venv/lib/python3.11/site-packages/ray/tune/examples/cifar10_pytorch.py +285 -0
- .venv/lib/python3.11/site-packages/ray/tune/examples/lightgbm_example.py +105 -0
- .venv/lib/python3.11/site-packages/ray/tune/execution/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/class_cache.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/cluster_info.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/experiment_state.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/insufficient_resources_manager.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/placement_groups.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/execution/class_cache.py +68 -0
- .venv/lib/python3.11/site-packages/ray/tune/execution/cluster_info.py +12 -0
- .venv/lib/python3.11/site-packages/ray/tune/execution/experiment_state.py +287 -0
- .venv/lib/python3.11/site-packages/ray/tune/execution/insufficient_resources_manager.py +167 -0
- .venv/lib/python3.11/site-packages/ray/tune/execution/placement_groups.py +131 -0
- .venv/lib/python3.11/site-packages/ray/tune/execution/tune_controller.py +2181 -0
- .venv/lib/python3.11/site-packages/ray/tune/experiment/__init__.py +4 -0
- .venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/config_parser.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/experiment.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/trial.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/experiment/config_parser.py +210 -0
- .venv/lib/python3.11/site-packages/ray/tune/experiment/experiment.py +445 -0
- .venv/lib/python3.11/site-packages/ray/tune/experiment/trial.py +1073 -0
- .venv/lib/python3.11/site-packages/ray/tune/integration/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/keras.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/lightgbm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/pytorch_lightning.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/ray_train.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/xgboost.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/integration/keras.py +28 -0
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/dynamic_tf_policy.cpython-311.pyc
ADDED
|
Binary file (60.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/dynamic_tf_policy_v2.cpython-311.pyc
ADDED
|
Binary file (45 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/eager_tf_policy.cpython-311.pyc
ADDED
|
Binary file (49.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/eager_tf_policy_v2.cpython-311.pyc
ADDED
|
Binary file (42.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy.cpython-311.pyc
ADDED
|
Binary file (74.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy_map.cpython-311.pyc
ADDED
|
Binary file (13.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/policy_template.cpython-311.pyc
ADDED
|
Binary file (22.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/rnn_sequencing.cpython-311.pyc
ADDED
|
Binary file (29.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/sample_batch.cpython-311.pyc
ADDED
|
Binary file (77.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_mixins.cpython-311.pyc
ADDED
|
Binary file (20.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_policy_template.cpython-311.pyc
ADDED
|
Binary file (18.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_policy.cpython-311.pyc
ADDED
|
Binary file (58.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/view_requirement.cpython-311.pyc
ADDED
|
Binary file (7.91 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__init__.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# isort: off
|
| 2 |
+
# Try import ray[tune] core requirements (defined in setup.py)
|
| 3 |
+
try:
|
| 4 |
+
import fsspec # noqa: F401
|
| 5 |
+
import pandas # noqa: F401
|
| 6 |
+
import pyarrow # noqa: F401
|
| 7 |
+
import requests # noqa: F401
|
| 8 |
+
except ImportError as exc:
|
| 9 |
+
raise ImportError(
|
| 10 |
+
"Can't import ray.tune as some dependencies are missing. "
|
| 11 |
+
'Run `pip install "ray[tune]"` to fix.'
|
| 12 |
+
) from exc
|
| 13 |
+
# isort: on
|
| 14 |
+
|
| 15 |
+
from ray.air.result import Result
|
| 16 |
+
from ray.tune.analysis import ExperimentAnalysis
|
| 17 |
+
from ray.tune.callback import Callback
|
| 18 |
+
from ray.tune.context import TuneContext, get_context
|
| 19 |
+
from ray.tune.error import TuneError
|
| 20 |
+
from ray.tune.execution.placement_groups import PlacementGroupFactory
|
| 21 |
+
from ray.tune.experiment import Experiment
|
| 22 |
+
from ray.tune.impl.config import CheckpointConfig, FailureConfig, RunConfig
|
| 23 |
+
from ray.tune.progress_reporter import (
|
| 24 |
+
CLIReporter,
|
| 25 |
+
JupyterNotebookReporter,
|
| 26 |
+
ProgressReporter,
|
| 27 |
+
)
|
| 28 |
+
from ray.tune.registry import register_env, register_trainable
|
| 29 |
+
from ray.tune.result_grid import ResultGrid
|
| 30 |
+
from ray.tune.schedulers import create_scheduler
|
| 31 |
+
from ray.tune.search import create_searcher, grid_search
|
| 32 |
+
from ray.tune.search.sample import (
|
| 33 |
+
choice,
|
| 34 |
+
lograndint,
|
| 35 |
+
loguniform,
|
| 36 |
+
qlograndint,
|
| 37 |
+
qloguniform,
|
| 38 |
+
qrandint,
|
| 39 |
+
qrandn,
|
| 40 |
+
quniform,
|
| 41 |
+
randint,
|
| 42 |
+
randn,
|
| 43 |
+
sample_from,
|
| 44 |
+
uniform,
|
| 45 |
+
)
|
| 46 |
+
from ray.tune.stopper import Stopper
|
| 47 |
+
from ray.tune.syncer import SyncConfig
|
| 48 |
+
from ray.tune.trainable import Trainable
|
| 49 |
+
from ray.tune.trainable.trainable_fn_utils import Checkpoint, get_checkpoint, report
|
| 50 |
+
from ray.tune.trainable.util import with_parameters, with_resources
|
| 51 |
+
from ray.tune.tune import run, run_experiments
|
| 52 |
+
from ray.tune.tune_config import ResumeConfig, TuneConfig
|
| 53 |
+
from ray.tune.tuner import Tuner
|
| 54 |
+
|
| 55 |
+
__all__ = [
|
| 56 |
+
"Trainable",
|
| 57 |
+
"Callback",
|
| 58 |
+
"TuneError",
|
| 59 |
+
"grid_search",
|
| 60 |
+
"register_env",
|
| 61 |
+
"register_trainable",
|
| 62 |
+
"run",
|
| 63 |
+
"run_experiments",
|
| 64 |
+
"with_parameters",
|
| 65 |
+
"with_resources",
|
| 66 |
+
"Stopper",
|
| 67 |
+
"Experiment",
|
| 68 |
+
"sample_from",
|
| 69 |
+
"uniform",
|
| 70 |
+
"quniform",
|
| 71 |
+
"choice",
|
| 72 |
+
"randint",
|
| 73 |
+
"lograndint",
|
| 74 |
+
"qrandint",
|
| 75 |
+
"qlograndint",
|
| 76 |
+
"randn",
|
| 77 |
+
"qrandn",
|
| 78 |
+
"loguniform",
|
| 79 |
+
"qloguniform",
|
| 80 |
+
"ExperimentAnalysis",
|
| 81 |
+
"CLIReporter",
|
| 82 |
+
"JupyterNotebookReporter",
|
| 83 |
+
"ProgressReporter",
|
| 84 |
+
"ResultGrid",
|
| 85 |
+
"create_searcher",
|
| 86 |
+
"create_scheduler",
|
| 87 |
+
"PlacementGroupFactory",
|
| 88 |
+
"Tuner",
|
| 89 |
+
"TuneConfig",
|
| 90 |
+
"ResumeConfig",
|
| 91 |
+
"RunConfig",
|
| 92 |
+
"CheckpointConfig",
|
| 93 |
+
"FailureConfig",
|
| 94 |
+
"Result",
|
| 95 |
+
"Checkpoint",
|
| 96 |
+
"get_checkpoint",
|
| 97 |
+
"report",
|
| 98 |
+
"get_context",
|
| 99 |
+
"TuneContext",
|
| 100 |
+
# TODO(justinvyu): [Deprecated]
|
| 101 |
+
"SyncConfig",
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
report.__module__ = "ray.tune"
|
| 105 |
+
get_checkpoint.__module__ = "ray.tune"
|
| 106 |
+
get_context.__module__ = "ray.tune"
|
| 107 |
+
TuneContext.__module__ = "ray.tune"
|
| 108 |
+
Checkpoint.__module__ = "ray.tune"
|
| 109 |
+
Result.__module__ = "ray.tune"
|
| 110 |
+
RunConfig.__module__ = "ray.tune"
|
| 111 |
+
CheckpointConfig.__module__ = "ray.tune"
|
| 112 |
+
FailureConfig.__module__ = "ray.tune"
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# DO NOT ADD ANYTHING AFTER THIS LINE.
|
.venv/lib/python3.11/site-packages/ray/tune/automl/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
raise DeprecationWarning("`ray.tune.automl` is deprecated in Ray 2.6.")
|
.venv/lib/python3.11/site-packages/ray/tune/automl/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (275 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/callback.py
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import warnings
|
| 3 |
+
from abc import ABCMeta
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import ray.tune
|
| 8 |
+
from ray.tune.utils.util import _atomic_save, _load_newest_checkpoint
|
| 9 |
+
from ray.util.annotations import DeveloperAPI, PublicAPI
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from ray.tune.experiment import Trial
|
| 13 |
+
from ray.tune.stopper import Stopper
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class _CallbackMeta(ABCMeta):
|
| 17 |
+
"""A helper metaclass to ensure container classes (e.g. CallbackList) have
|
| 18 |
+
implemented all the callback methods (e.g. `on_*`).
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __new__(mcs, name: str, bases: Tuple[type], attrs: Dict[str, Any]) -> type:
|
| 22 |
+
cls = super().__new__(mcs, name, bases, attrs)
|
| 23 |
+
|
| 24 |
+
if mcs.need_check(cls, name, bases, attrs):
|
| 25 |
+
mcs.check(cls, name, bases, attrs)
|
| 26 |
+
|
| 27 |
+
return cls
|
| 28 |
+
|
| 29 |
+
@classmethod
|
| 30 |
+
def need_check(
|
| 31 |
+
mcs, cls: type, name: str, bases: Tuple[type], attrs: Dict[str, Any]
|
| 32 |
+
) -> bool:
|
| 33 |
+
return attrs.get("IS_CALLBACK_CONTAINER", False)
|
| 34 |
+
|
| 35 |
+
@classmethod
|
| 36 |
+
def check(
|
| 37 |
+
mcs, cls: type, name: str, bases: Tuple[type], attrs: Dict[str, Any]
|
| 38 |
+
) -> None:
|
| 39 |
+
methods = set()
|
| 40 |
+
for base in bases:
|
| 41 |
+
methods.update(
|
| 42 |
+
attr_name
|
| 43 |
+
for attr_name, attr in vars(base).items()
|
| 44 |
+
if mcs.need_override_by_subclass(attr_name, attr)
|
| 45 |
+
)
|
| 46 |
+
overridden = {
|
| 47 |
+
attr_name
|
| 48 |
+
for attr_name, attr in attrs.items()
|
| 49 |
+
if mcs.need_override_by_subclass(attr_name, attr)
|
| 50 |
+
}
|
| 51 |
+
missing = methods.difference(overridden)
|
| 52 |
+
if missing:
|
| 53 |
+
raise TypeError(
|
| 54 |
+
f"Found missing callback method: {missing} "
|
| 55 |
+
f"in class {cls.__module__}.{cls.__qualname__}."
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
@classmethod
|
| 59 |
+
def need_override_by_subclass(mcs, attr_name: str, attr: Any) -> bool:
|
| 60 |
+
return (
|
| 61 |
+
(
|
| 62 |
+
attr_name.startswith("on_")
|
| 63 |
+
and not attr_name.startswith("on_trainer_init")
|
| 64 |
+
)
|
| 65 |
+
or attr_name == "setup"
|
| 66 |
+
) and callable(attr)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@PublicAPI(stability="beta")
|
| 70 |
+
class Callback(metaclass=_CallbackMeta):
|
| 71 |
+
"""Tune base callback that can be extended and passed to a ``TrialRunner``
|
| 72 |
+
|
| 73 |
+
Tune callbacks are called from within the ``TrialRunner`` class. There are
|
| 74 |
+
several hooks that can be used, all of which are found in the submethod
|
| 75 |
+
definitions of this base class.
|
| 76 |
+
|
| 77 |
+
The parameters passed to the ``**info`` dict vary between hooks. The
|
| 78 |
+
parameters passed are described in the docstrings of the methods.
|
| 79 |
+
|
| 80 |
+
This example will print a metric each time a result is received:
|
| 81 |
+
|
| 82 |
+
.. testcode::
|
| 83 |
+
|
| 84 |
+
from ray import train, tune
|
| 85 |
+
from ray.tune import Callback
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class MyCallback(Callback):
|
| 89 |
+
def on_trial_result(self, iteration, trials, trial, result,
|
| 90 |
+
**info):
|
| 91 |
+
print(f"Got result: {result['metric']}")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def train_func(config):
|
| 95 |
+
for i in range(10):
|
| 96 |
+
tune.report(metric=i)
|
| 97 |
+
|
| 98 |
+
tuner = tune.Tuner(
|
| 99 |
+
train_func,
|
| 100 |
+
run_config=train.RunConfig(
|
| 101 |
+
callbacks=[MyCallback()]
|
| 102 |
+
)
|
| 103 |
+
)
|
| 104 |
+
tuner.fit()
|
| 105 |
+
|
| 106 |
+
.. testoutput::
|
| 107 |
+
:hide:
|
| 108 |
+
|
| 109 |
+
...
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
# File templates for any artifacts written by this callback
|
| 113 |
+
# These files should live in the `trial.local_path` for each trial.
|
| 114 |
+
# TODO(ml-team): Make this more visible to users to override. Internal use for now.
|
| 115 |
+
_SAVED_FILE_TEMPLATES = []
|
| 116 |
+
|
| 117 |
+
# arguments here match Experiment.public_spec
|
| 118 |
+
def setup(
|
| 119 |
+
self,
|
| 120 |
+
stop: Optional["Stopper"] = None,
|
| 121 |
+
num_samples: Optional[int] = None,
|
| 122 |
+
total_num_samples: Optional[int] = None,
|
| 123 |
+
**info,
|
| 124 |
+
):
|
| 125 |
+
"""Called once at the very beginning of training.
|
| 126 |
+
|
| 127 |
+
Any Callback setup should be added here (setting environment
|
| 128 |
+
variables, etc.)
|
| 129 |
+
|
| 130 |
+
Arguments:
|
| 131 |
+
stop: Stopping criteria.
|
| 132 |
+
If ``time_budget_s`` was passed to ``train.RunConfig``, a
|
| 133 |
+
``TimeoutStopper`` will be passed here, either by itself
|
| 134 |
+
or as a part of a ``CombinedStopper``.
|
| 135 |
+
num_samples: Number of times to sample from the
|
| 136 |
+
hyperparameter space. Defaults to 1. If `grid_search` is
|
| 137 |
+
provided as an argument, the grid will be repeated
|
| 138 |
+
`num_samples` of times. If this is -1, (virtually) infinite
|
| 139 |
+
samples are generated until a stopping condition is met.
|
| 140 |
+
total_num_samples: Total number of samples factoring
|
| 141 |
+
in grid search samplers.
|
| 142 |
+
**info: Kwargs dict for forward compatibility.
|
| 143 |
+
"""
|
| 144 |
+
pass
|
| 145 |
+
|
| 146 |
+
def on_step_begin(self, iteration: int, trials: List["Trial"], **info):
|
| 147 |
+
"""Called at the start of each tuning loop step.
|
| 148 |
+
|
| 149 |
+
Arguments:
|
| 150 |
+
iteration: Number of iterations of the tuning loop.
|
| 151 |
+
trials: List of trials.
|
| 152 |
+
**info: Kwargs dict for forward compatibility.
|
| 153 |
+
"""
|
| 154 |
+
pass
|
| 155 |
+
|
| 156 |
+
def on_step_end(self, iteration: int, trials: List["Trial"], **info):
|
| 157 |
+
"""Called at the end of each tuning loop step.
|
| 158 |
+
|
| 159 |
+
The iteration counter is increased before this hook is called.
|
| 160 |
+
|
| 161 |
+
Arguments:
|
| 162 |
+
iteration: Number of iterations of the tuning loop.
|
| 163 |
+
trials: List of trials.
|
| 164 |
+
**info: Kwargs dict for forward compatibility.
|
| 165 |
+
"""
|
| 166 |
+
pass
|
| 167 |
+
|
| 168 |
+
def on_trial_start(
|
| 169 |
+
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
|
| 170 |
+
):
|
| 171 |
+
"""Called after starting a trial instance.
|
| 172 |
+
|
| 173 |
+
Arguments:
|
| 174 |
+
iteration: Number of iterations of the tuning loop.
|
| 175 |
+
trials: List of trials.
|
| 176 |
+
trial: Trial that just has been started.
|
| 177 |
+
**info: Kwargs dict for forward compatibility.
|
| 178 |
+
|
| 179 |
+
"""
|
| 180 |
+
pass
|
| 181 |
+
|
| 182 |
+
def on_trial_restore(
|
| 183 |
+
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
|
| 184 |
+
):
|
| 185 |
+
"""Called after restoring a trial instance.
|
| 186 |
+
|
| 187 |
+
Arguments:
|
| 188 |
+
iteration: Number of iterations of the tuning loop.
|
| 189 |
+
trials: List of trials.
|
| 190 |
+
trial: Trial that just has been restored.
|
| 191 |
+
**info: Kwargs dict for forward compatibility.
|
| 192 |
+
"""
|
| 193 |
+
pass
|
| 194 |
+
|
| 195 |
+
def on_trial_save(
|
| 196 |
+
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
|
| 197 |
+
):
|
| 198 |
+
"""Called after receiving a checkpoint from a trial.
|
| 199 |
+
|
| 200 |
+
Arguments:
|
| 201 |
+
iteration: Number of iterations of the tuning loop.
|
| 202 |
+
trials: List of trials.
|
| 203 |
+
trial: Trial that just saved a checkpoint.
|
| 204 |
+
**info: Kwargs dict for forward compatibility.
|
| 205 |
+
"""
|
| 206 |
+
pass
|
| 207 |
+
|
| 208 |
+
def on_trial_result(
|
| 209 |
+
self,
|
| 210 |
+
iteration: int,
|
| 211 |
+
trials: List["Trial"],
|
| 212 |
+
trial: "Trial",
|
| 213 |
+
result: Dict,
|
| 214 |
+
**info,
|
| 215 |
+
):
|
| 216 |
+
"""Called after receiving a result from a trial.
|
| 217 |
+
|
| 218 |
+
The search algorithm and scheduler are notified before this
|
| 219 |
+
hook is called.
|
| 220 |
+
|
| 221 |
+
Arguments:
|
| 222 |
+
iteration: Number of iterations of the tuning loop.
|
| 223 |
+
trials: List of trials.
|
| 224 |
+
trial: Trial that just sent a result.
|
| 225 |
+
result: Result that the trial sent.
|
| 226 |
+
**info: Kwargs dict for forward compatibility.
|
| 227 |
+
"""
|
| 228 |
+
pass
|
| 229 |
+
|
| 230 |
+
def on_trial_complete(
|
| 231 |
+
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
|
| 232 |
+
):
|
| 233 |
+
"""Called after a trial instance completed.
|
| 234 |
+
|
| 235 |
+
The search algorithm and scheduler are notified before this
|
| 236 |
+
hook is called.
|
| 237 |
+
|
| 238 |
+
Arguments:
|
| 239 |
+
iteration: Number of iterations of the tuning loop.
|
| 240 |
+
trials: List of trials.
|
| 241 |
+
trial: Trial that just has been completed.
|
| 242 |
+
**info: Kwargs dict for forward compatibility.
|
| 243 |
+
"""
|
| 244 |
+
pass
|
| 245 |
+
|
| 246 |
+
def on_trial_recover(
|
| 247 |
+
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
|
| 248 |
+
):
|
| 249 |
+
"""Called after a trial instance failed (errored) but the trial is scheduled
|
| 250 |
+
for retry.
|
| 251 |
+
|
| 252 |
+
The search algorithm and scheduler are not notified.
|
| 253 |
+
|
| 254 |
+
Arguments:
|
| 255 |
+
iteration: Number of iterations of the tuning loop.
|
| 256 |
+
trials: List of trials.
|
| 257 |
+
trial: Trial that just has errored.
|
| 258 |
+
**info: Kwargs dict for forward compatibility.
|
| 259 |
+
"""
|
| 260 |
+
pass
|
| 261 |
+
|
| 262 |
+
def on_trial_error(
|
| 263 |
+
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
|
| 264 |
+
):
|
| 265 |
+
"""Called after a trial instance failed (errored).
|
| 266 |
+
|
| 267 |
+
The search algorithm and scheduler are notified before this
|
| 268 |
+
hook is called.
|
| 269 |
+
|
| 270 |
+
Arguments:
|
| 271 |
+
iteration: Number of iterations of the tuning loop.
|
| 272 |
+
trials: List of trials.
|
| 273 |
+
trial: Trial that just has errored.
|
| 274 |
+
**info: Kwargs dict for forward compatibility.
|
| 275 |
+
"""
|
| 276 |
+
pass
|
| 277 |
+
|
| 278 |
+
def on_checkpoint(
|
| 279 |
+
self,
|
| 280 |
+
iteration: int,
|
| 281 |
+
trials: List["Trial"],
|
| 282 |
+
trial: "Trial",
|
| 283 |
+
checkpoint: "ray.tune.Checkpoint",
|
| 284 |
+
**info,
|
| 285 |
+
):
|
| 286 |
+
"""Called after a trial saved a checkpoint with Tune.
|
| 287 |
+
|
| 288 |
+
Arguments:
|
| 289 |
+
iteration: Number of iterations of the tuning loop.
|
| 290 |
+
trials: List of trials.
|
| 291 |
+
trial: Trial that just has errored.
|
| 292 |
+
checkpoint: Checkpoint object that has been saved
|
| 293 |
+
by the trial.
|
| 294 |
+
**info: Kwargs dict for forward compatibility.
|
| 295 |
+
"""
|
| 296 |
+
pass
|
| 297 |
+
|
| 298 |
+
def on_experiment_end(self, trials: List["Trial"], **info):
|
| 299 |
+
"""Called after experiment is over and all trials have concluded.
|
| 300 |
+
|
| 301 |
+
Arguments:
|
| 302 |
+
trials: List of trials.
|
| 303 |
+
**info: Kwargs dict for forward compatibility.
|
| 304 |
+
"""
|
| 305 |
+
pass
|
| 306 |
+
|
| 307 |
+
def get_state(self) -> Optional[Dict]:
|
| 308 |
+
"""Get the state of the callback.
|
| 309 |
+
|
| 310 |
+
This method should be implemented by subclasses to return a dictionary
|
| 311 |
+
representation of the object's current state.
|
| 312 |
+
|
| 313 |
+
This is called automatically by Tune to periodically checkpoint callback state.
|
| 314 |
+
Upon :ref:`Tune experiment restoration <tune-experiment-level-fault-tolerance>`,
|
| 315 |
+
callback state will be restored via :meth:`~ray.tune.Callback.set_state`.
|
| 316 |
+
|
| 317 |
+
.. testcode::
|
| 318 |
+
|
| 319 |
+
from typing import Dict, List, Optional
|
| 320 |
+
|
| 321 |
+
from ray.tune import Callback
|
| 322 |
+
from ray.tune.experiment import Trial
|
| 323 |
+
|
| 324 |
+
class MyCallback(Callback):
|
| 325 |
+
def __init__(self):
|
| 326 |
+
self._trial_ids = set()
|
| 327 |
+
|
| 328 |
+
def on_trial_start(
|
| 329 |
+
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
|
| 330 |
+
):
|
| 331 |
+
self._trial_ids.add(trial.trial_id)
|
| 332 |
+
|
| 333 |
+
def get_state(self) -> Optional[Dict]:
|
| 334 |
+
return {"trial_ids": self._trial_ids.copy()}
|
| 335 |
+
|
| 336 |
+
def set_state(self, state: Dict) -> Optional[Dict]:
|
| 337 |
+
self._trial_ids = state["trial_ids"]
|
| 338 |
+
|
| 339 |
+
Returns:
|
| 340 |
+
dict: State of the callback. Should be `None` if the callback does not
|
| 341 |
+
have any state to save (this is the default).
|
| 342 |
+
"""
|
| 343 |
+
return None
|
| 344 |
+
|
| 345 |
+
def set_state(self, state: Dict):
|
| 346 |
+
"""Set the state of the callback.
|
| 347 |
+
|
| 348 |
+
This method should be implemented by subclasses to restore the callback's
|
| 349 |
+
state based on the given dict state.
|
| 350 |
+
|
| 351 |
+
This is used automatically by Tune to restore checkpoint callback state
|
| 352 |
+
on :ref:`Tune experiment restoration <tune-experiment-level-fault-tolerance>`.
|
| 353 |
+
|
| 354 |
+
See :meth:`~ray.tune.Callback.get_state` for an example implementation.
|
| 355 |
+
|
| 356 |
+
Args:
|
| 357 |
+
state: State of the callback.
|
| 358 |
+
"""
|
| 359 |
+
pass
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
@DeveloperAPI
|
| 363 |
+
class CallbackList(Callback):
|
| 364 |
+
"""Call multiple callbacks at once."""
|
| 365 |
+
|
| 366 |
+
IS_CALLBACK_CONTAINER = True
|
| 367 |
+
CKPT_FILE_TMPL = "callback-states-{}.pkl"
|
| 368 |
+
|
| 369 |
+
def __init__(self, callbacks: List[Callback]):
|
| 370 |
+
self._callbacks = callbacks
|
| 371 |
+
|
| 372 |
+
def setup(self, **info):
|
| 373 |
+
for callback in self._callbacks:
|
| 374 |
+
try:
|
| 375 |
+
callback.setup(**info)
|
| 376 |
+
except TypeError as e:
|
| 377 |
+
if "argument" in str(e):
|
| 378 |
+
warnings.warn(
|
| 379 |
+
"Please update `setup` method in callback "
|
| 380 |
+
f"`{callback.__class__}` to match the method signature"
|
| 381 |
+
" in `ray.tune.callback.Callback`.",
|
| 382 |
+
FutureWarning,
|
| 383 |
+
)
|
| 384 |
+
callback.setup()
|
| 385 |
+
else:
|
| 386 |
+
raise e
|
| 387 |
+
|
| 388 |
+
def on_step_begin(self, **info):
|
| 389 |
+
for callback in self._callbacks:
|
| 390 |
+
callback.on_step_begin(**info)
|
| 391 |
+
|
| 392 |
+
def on_step_end(self, **info):
|
| 393 |
+
for callback in self._callbacks:
|
| 394 |
+
callback.on_step_end(**info)
|
| 395 |
+
|
| 396 |
+
def on_trial_start(self, **info):
|
| 397 |
+
for callback in self._callbacks:
|
| 398 |
+
callback.on_trial_start(**info)
|
| 399 |
+
|
| 400 |
+
def on_trial_restore(self, **info):
|
| 401 |
+
for callback in self._callbacks:
|
| 402 |
+
callback.on_trial_restore(**info)
|
| 403 |
+
|
| 404 |
+
def on_trial_save(self, **info):
|
| 405 |
+
for callback in self._callbacks:
|
| 406 |
+
callback.on_trial_save(**info)
|
| 407 |
+
|
| 408 |
+
def on_trial_result(self, **info):
|
| 409 |
+
for callback in self._callbacks:
|
| 410 |
+
callback.on_trial_result(**info)
|
| 411 |
+
|
| 412 |
+
def on_trial_complete(self, **info):
|
| 413 |
+
for callback in self._callbacks:
|
| 414 |
+
callback.on_trial_complete(**info)
|
| 415 |
+
|
| 416 |
+
def on_trial_recover(self, **info):
|
| 417 |
+
for callback in self._callbacks:
|
| 418 |
+
callback.on_trial_recover(**info)
|
| 419 |
+
|
| 420 |
+
def on_trial_error(self, **info):
|
| 421 |
+
for callback in self._callbacks:
|
| 422 |
+
callback.on_trial_error(**info)
|
| 423 |
+
|
| 424 |
+
def on_checkpoint(self, **info):
|
| 425 |
+
for callback in self._callbacks:
|
| 426 |
+
callback.on_checkpoint(**info)
|
| 427 |
+
|
| 428 |
+
def on_experiment_end(self, **info):
|
| 429 |
+
for callback in self._callbacks:
|
| 430 |
+
callback.on_experiment_end(**info)
|
| 431 |
+
|
| 432 |
+
def get_state(self) -> Optional[Dict]:
|
| 433 |
+
"""Gets the state of all callbacks contained within this list.
|
| 434 |
+
If there are no stateful callbacks, then None will be returned in order
|
| 435 |
+
to avoid saving an unnecessary callback checkpoint file."""
|
| 436 |
+
state = {}
|
| 437 |
+
any_stateful_callbacks = False
|
| 438 |
+
for i, callback in enumerate(self._callbacks):
|
| 439 |
+
callback_state = callback.get_state()
|
| 440 |
+
if callback_state:
|
| 441 |
+
any_stateful_callbacks = True
|
| 442 |
+
state[i] = callback_state
|
| 443 |
+
if not any_stateful_callbacks:
|
| 444 |
+
return None
|
| 445 |
+
return state
|
| 446 |
+
|
| 447 |
+
def set_state(self, state: Dict):
|
| 448 |
+
"""Sets the state for all callbacks contained within this list.
|
| 449 |
+
Skips setting state for all stateless callbacks where `get_state`
|
| 450 |
+
returned None."""
|
| 451 |
+
for i, callback in enumerate(self._callbacks):
|
| 452 |
+
callback_state = state.get(i, None)
|
| 453 |
+
if callback_state:
|
| 454 |
+
callback.set_state(callback_state)
|
| 455 |
+
|
| 456 |
+
def save_to_dir(self, checkpoint_dir: str, session_str: str = "default"):
|
| 457 |
+
"""Save the state of the callback list to the checkpoint_dir.
|
| 458 |
+
|
| 459 |
+
Args:
|
| 460 |
+
checkpoint_dir: directory where the checkpoint is stored.
|
| 461 |
+
session_str: Unique identifier of the current run session (ex: timestamp).
|
| 462 |
+
"""
|
| 463 |
+
state_dict = self.get_state()
|
| 464 |
+
|
| 465 |
+
if state_dict:
|
| 466 |
+
file_name = self.CKPT_FILE_TMPL.format(session_str)
|
| 467 |
+
tmp_file_name = f".tmp-{file_name}"
|
| 468 |
+
_atomic_save(
|
| 469 |
+
state=state_dict,
|
| 470 |
+
checkpoint_dir=checkpoint_dir,
|
| 471 |
+
file_name=file_name,
|
| 472 |
+
tmp_file_name=tmp_file_name,
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
def restore_from_dir(self, checkpoint_dir: str):
|
| 476 |
+
"""Restore the state of the list of callbacks from the checkpoint_dir.
|
| 477 |
+
|
| 478 |
+
You should check if it's possible to restore with `can_restore`
|
| 479 |
+
before calling this method.
|
| 480 |
+
|
| 481 |
+
Args:
|
| 482 |
+
checkpoint_dir: directory where the checkpoint is stored.
|
| 483 |
+
|
| 484 |
+
Raises:
|
| 485 |
+
RuntimeError: if unable to find checkpoint.
|
| 486 |
+
NotImplementedError: if the `set_state` method is not implemented.
|
| 487 |
+
"""
|
| 488 |
+
state_dict = _load_newest_checkpoint(
|
| 489 |
+
checkpoint_dir, self.CKPT_FILE_TMPL.format("*")
|
| 490 |
+
)
|
| 491 |
+
if not state_dict:
|
| 492 |
+
raise RuntimeError(
|
| 493 |
+
"Unable to find checkpoint in {}.".format(checkpoint_dir)
|
| 494 |
+
)
|
| 495 |
+
self.set_state(state_dict)
|
| 496 |
+
|
| 497 |
+
def can_restore(self, checkpoint_dir: str) -> bool:
|
| 498 |
+
"""Check if the checkpoint_dir contains the saved state for this callback list.
|
| 499 |
+
|
| 500 |
+
Returns:
|
| 501 |
+
can_restore: True if the checkpoint_dir contains a file of the
|
| 502 |
+
format `CKPT_FILE_TMPL`. False otherwise.
|
| 503 |
+
"""
|
| 504 |
+
return any(
|
| 505 |
+
glob.iglob(Path(checkpoint_dir, self.CKPT_FILE_TMPL.format("*")).as_posix())
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
def __len__(self) -> int:
|
| 509 |
+
return len(self._callbacks)
|
| 510 |
+
|
| 511 |
+
def __getitem__(self, i: int) -> "Callback":
|
| 512 |
+
return self._callbacks[i]
|
.venv/lib/python3.11/site-packages/ray/tune/constants.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ==================================================
|
| 2 |
+
# Environment Variables
|
| 3 |
+
# ==================================================
|
| 4 |
+
|
| 5 |
+
# NOTE: When adding a new environment variable, please track it in this list.
|
| 6 |
+
TUNE_ENV_VARS = {
|
| 7 |
+
"RAY_AIR_LOCAL_CACHE_DIR",
|
| 8 |
+
"TUNE_DISABLE_AUTO_CALLBACK_LOGGERS",
|
| 9 |
+
"TUNE_DISABLE_AUTO_INIT",
|
| 10 |
+
"TUNE_DISABLE_DATED_SUBDIR",
|
| 11 |
+
"TUNE_DISABLE_STRICT_METRIC_CHECKING",
|
| 12 |
+
"TUNE_DISABLE_SIGINT_HANDLER",
|
| 13 |
+
"TUNE_FORCE_TRIAL_CLEANUP_S",
|
| 14 |
+
"TUNE_FUNCTION_THREAD_TIMEOUT_S",
|
| 15 |
+
"TUNE_GLOBAL_CHECKPOINT_S",
|
| 16 |
+
"TUNE_MAX_LEN_IDENTIFIER",
|
| 17 |
+
"TUNE_MAX_PENDING_TRIALS_PG",
|
| 18 |
+
"TUNE_PLACEMENT_GROUP_PREFIX",
|
| 19 |
+
"TUNE_PLACEMENT_GROUP_RECON_INTERVAL",
|
| 20 |
+
"TUNE_PRINT_ALL_TRIAL_ERRORS",
|
| 21 |
+
"TUNE_RESULT_DIR",
|
| 22 |
+
"TUNE_RESULT_BUFFER_LENGTH",
|
| 23 |
+
"TUNE_RESULT_DELIM",
|
| 24 |
+
"TUNE_RESULT_BUFFER_MAX_TIME_S",
|
| 25 |
+
"TUNE_RESULT_BUFFER_MIN_TIME_S",
|
| 26 |
+
"TUNE_WARN_THRESHOLD_S",
|
| 27 |
+
"TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S",
|
| 28 |
+
"TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S_AUTOSCALER",
|
| 29 |
+
"TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S",
|
| 30 |
+
"TUNE_STATE_REFRESH_PERIOD",
|
| 31 |
+
"TUNE_RESTORE_RETRY_NUM",
|
| 32 |
+
}
|
.venv/lib/python3.11/site-packages/ray/tune/context.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
from typing import Any, Dict, Optional
|
| 3 |
+
|
| 4 |
+
from ray.train._internal import session
|
| 5 |
+
from ray.train.constants import _v2_migration_warnings_enabled
|
| 6 |
+
from ray.train.context import TrainContext as TrainV1Context
|
| 7 |
+
from ray.train.utils import _copy_doc
|
| 8 |
+
from ray.tune.execution.placement_groups import PlacementGroupFactory
|
| 9 |
+
from ray.util.annotations import Deprecated, PublicAPI
|
| 10 |
+
|
| 11 |
+
# The context singleton on this process.
|
| 12 |
+
_tune_context: Optional["TuneContext"] = None
|
| 13 |
+
_tune_context_lock = threading.Lock()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE = (
|
| 17 |
+
"`{}` is deprecated for Ray Tune because there is no concept of worker ranks "
|
| 18 |
+
"for Ray Tune, so these methods only make sense to use in the context of "
|
| 19 |
+
"a Ray Train worker."
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@PublicAPI(stability="beta")
|
| 24 |
+
class TuneContext(TrainV1Context):
|
| 25 |
+
"""Context to access metadata within Ray Tune functions."""
|
| 26 |
+
|
| 27 |
+
# NOTE: These methods are deprecated on the TrainContext, but are still
|
| 28 |
+
# available on the TuneContext. Re-defining them here to avoid the
|
| 29 |
+
# deprecation warnings.
|
| 30 |
+
|
| 31 |
+
@_copy_doc(session.get_trial_name)
|
| 32 |
+
def get_trial_name(self) -> str:
|
| 33 |
+
return session.get_trial_name()
|
| 34 |
+
|
| 35 |
+
@_copy_doc(session.get_trial_id)
|
| 36 |
+
def get_trial_id(self) -> str:
|
| 37 |
+
return session.get_trial_id()
|
| 38 |
+
|
| 39 |
+
@_copy_doc(session.get_trial_resources)
|
| 40 |
+
def get_trial_resources(self) -> PlacementGroupFactory:
|
| 41 |
+
return session.get_trial_resources()
|
| 42 |
+
|
| 43 |
+
@_copy_doc(session.get_trial_dir)
|
| 44 |
+
def get_trial_dir(self) -> str:
|
| 45 |
+
return session.get_trial_dir()
|
| 46 |
+
|
| 47 |
+
# Deprecated APIs
|
| 48 |
+
|
| 49 |
+
@Deprecated
|
| 50 |
+
def get_metadata(self) -> Dict[str, Any]:
|
| 51 |
+
raise DeprecationWarning(
|
| 52 |
+
"`get_metadata` is deprecated for Ray Tune, as it has never been usable."
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
@Deprecated(
|
| 56 |
+
message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_world_size"),
|
| 57 |
+
warning=_v2_migration_warnings_enabled(),
|
| 58 |
+
)
|
| 59 |
+
@_copy_doc(TrainV1Context.get_world_size)
|
| 60 |
+
def get_world_size(self) -> int:
|
| 61 |
+
return session.get_world_size()
|
| 62 |
+
|
| 63 |
+
@Deprecated(
|
| 64 |
+
message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_world_rank"),
|
| 65 |
+
warning=_v2_migration_warnings_enabled(),
|
| 66 |
+
)
|
| 67 |
+
@_copy_doc(TrainV1Context.get_world_rank)
|
| 68 |
+
def get_world_rank(self) -> int:
|
| 69 |
+
return session.get_world_rank()
|
| 70 |
+
|
| 71 |
+
@Deprecated(
|
| 72 |
+
message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_local_rank"),
|
| 73 |
+
warning=_v2_migration_warnings_enabled(),
|
| 74 |
+
)
|
| 75 |
+
@_copy_doc(TrainV1Context.get_local_rank)
|
| 76 |
+
def get_local_rank(self) -> int:
|
| 77 |
+
return session.get_local_rank()
|
| 78 |
+
|
| 79 |
+
@Deprecated(
|
| 80 |
+
message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format(
|
| 81 |
+
"get_local_world_size"
|
| 82 |
+
),
|
| 83 |
+
warning=_v2_migration_warnings_enabled(),
|
| 84 |
+
)
|
| 85 |
+
@_copy_doc(TrainV1Context.get_local_world_size)
|
| 86 |
+
def get_local_world_size(self) -> int:
|
| 87 |
+
return session.get_local_world_size()
|
| 88 |
+
|
| 89 |
+
@Deprecated(
|
| 90 |
+
message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_node_rank"),
|
| 91 |
+
warning=_v2_migration_warnings_enabled(),
|
| 92 |
+
)
|
| 93 |
+
@_copy_doc(TrainV1Context.get_node_rank)
|
| 94 |
+
def get_node_rank(self) -> int:
|
| 95 |
+
return session.get_node_rank()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@PublicAPI(stability="beta")
|
| 99 |
+
def get_context() -> TuneContext:
|
| 100 |
+
"""Get or create a singleton Ray Tune context.
|
| 101 |
+
|
| 102 |
+
The context is only available in a tune function passed to the `ray.tune.Tuner`.
|
| 103 |
+
|
| 104 |
+
See the :class:`~ray.tune.TuneContext` API reference to see available methods.
|
| 105 |
+
"""
|
| 106 |
+
global _tune_context
|
| 107 |
+
|
| 108 |
+
with _tune_context_lock:
|
| 109 |
+
if _tune_context is None:
|
| 110 |
+
# TODO(justinvyu): This default should be a dummy context
|
| 111 |
+
# that is only used for testing / running outside of Tune.
|
| 112 |
+
_tune_context = TuneContext()
|
| 113 |
+
return _tune_context
|
.venv/lib/python3.11/site-packages/ray/tune/error.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.util.annotations import PublicAPI
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@PublicAPI
|
| 5 |
+
class TuneError(Exception):
|
| 6 |
+
"""General error class raised by ray.tune."""
|
| 7 |
+
|
| 8 |
+
pass
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class _AbortTrialExecution(TuneError):
|
| 12 |
+
"""Error that indicates a trial should not be retried."""
|
| 13 |
+
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class _SubCategoryTuneError(TuneError):
|
| 18 |
+
"""The more specific TuneError that happens for a certain Tune
|
| 19 |
+
subroutine. For example starting/stopping a trial.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, traceback_str: str):
|
| 23 |
+
self.traceback_str = traceback_str
|
| 24 |
+
|
| 25 |
+
def __str__(self):
|
| 26 |
+
return self.traceback_str
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class _TuneStopTrialError(_SubCategoryTuneError):
|
| 30 |
+
"""Error that happens when stopping a tune trial."""
|
| 31 |
+
|
| 32 |
+
pass
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class _TuneStartTrialError(_SubCategoryTuneError):
|
| 36 |
+
"""Error that happens when starting a tune trial."""
|
| 37 |
+
|
| 38 |
+
pass
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class _TuneNoNextExecutorEventError(_SubCategoryTuneError):
|
| 42 |
+
"""Error that happens when waiting to get the next event to
|
| 43 |
+
handle from RayTrialExecutor.
|
| 44 |
+
|
| 45 |
+
Note: RayTaskError will be raised by itself and will not be using
|
| 46 |
+
this category. This category is for everything else."""
|
| 47 |
+
|
| 48 |
+
pass
|
.venv/lib/python3.11/site-packages/ray/tune/examples/cifar10_pytorch.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ruff: noqa
|
| 2 |
+
# fmt: off
|
| 3 |
+
|
| 4 |
+
# __import_begin__
|
| 5 |
+
import os
|
| 6 |
+
import tempfile
|
| 7 |
+
from typing import Dict
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torch.optim as optim
|
| 14 |
+
import torchvision
|
| 15 |
+
import torchvision.transforms as transforms
|
| 16 |
+
from filelock import FileLock
|
| 17 |
+
from torch.utils.data import random_split
|
| 18 |
+
|
| 19 |
+
import ray
|
| 20 |
+
from ray import train, tune
|
| 21 |
+
from ray.train import Checkpoint
|
| 22 |
+
from ray.tune.schedulers import ASHAScheduler
|
| 23 |
+
|
| 24 |
+
# __import_end__
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# __load_data_begin__
|
| 28 |
+
DATA_DIR = tempfile.mkdtemp()
|
| 29 |
+
|
| 30 |
+
def load_data(data_dir):
|
| 31 |
+
transform = transforms.Compose([
|
| 32 |
+
transforms.ToTensor(),
|
| 33 |
+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
| 34 |
+
])
|
| 35 |
+
|
| 36 |
+
# We add FileLock here because multiple workers will want to
|
| 37 |
+
# download data, and this may cause overwrites since
|
| 38 |
+
# DataLoader is not threadsafe.
|
| 39 |
+
with FileLock(os.path.expanduser("~/.data.lock")):
|
| 40 |
+
trainset = torchvision.datasets.CIFAR10(
|
| 41 |
+
root=data_dir, train=True, download=True, transform=transform)
|
| 42 |
+
|
| 43 |
+
testset = torchvision.datasets.CIFAR10(
|
| 44 |
+
root=data_dir, train=False, download=True, transform=transform)
|
| 45 |
+
|
| 46 |
+
return trainset, testset
|
| 47 |
+
# __load_data_end__
|
| 48 |
+
|
| 49 |
+
def load_test_data():
|
| 50 |
+
# Loads a fake dataset for testing so it doesn't rely on external download.
|
| 51 |
+
trainset = torchvision.datasets.FakeData(
|
| 52 |
+
128, (3, 32, 32), num_classes=10, transform=transforms.ToTensor()
|
| 53 |
+
)
|
| 54 |
+
testset = torchvision.datasets.FakeData(
|
| 55 |
+
16, (3, 32, 32), num_classes=10, transform=transforms.ToTensor()
|
| 56 |
+
)
|
| 57 |
+
return trainset, testset
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# __net_begin__
|
| 61 |
+
class Net(nn.Module):
|
| 62 |
+
def __init__(self, l1=120, l2=84):
|
| 63 |
+
super(Net, self).__init__()
|
| 64 |
+
self.conv1 = nn.Conv2d(3, 6, 5)
|
| 65 |
+
self.pool = nn.MaxPool2d(2, 2)
|
| 66 |
+
self.conv2 = nn.Conv2d(6, 16, 5)
|
| 67 |
+
self.fc1 = nn.Linear(16 * 5 * 5, l1)
|
| 68 |
+
self.fc2 = nn.Linear(l1, l2)
|
| 69 |
+
self.fc3 = nn.Linear(l2, 10)
|
| 70 |
+
|
| 71 |
+
def forward(self, x):
|
| 72 |
+
x = self.pool(F.relu(self.conv1(x)))
|
| 73 |
+
x = self.pool(F.relu(self.conv2(x)))
|
| 74 |
+
x = x.view(-1, 16 * 5 * 5)
|
| 75 |
+
x = F.relu(self.fc1(x))
|
| 76 |
+
x = F.relu(self.fc2(x))
|
| 77 |
+
x = self.fc3(x)
|
| 78 |
+
return x
|
| 79 |
+
# __net_end__
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# __train_begin__
|
| 83 |
+
def train_cifar(config):
|
| 84 |
+
net = Net(config["l1"], config["l2"])
|
| 85 |
+
|
| 86 |
+
device = "cpu"
|
| 87 |
+
if torch.cuda.is_available():
|
| 88 |
+
device = "cuda:0"
|
| 89 |
+
if torch.cuda.device_count() > 1:
|
| 90 |
+
net = nn.DataParallel(net)
|
| 91 |
+
net.to(device)
|
| 92 |
+
|
| 93 |
+
criterion = nn.CrossEntropyLoss()
|
| 94 |
+
optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
|
| 95 |
+
|
| 96 |
+
# Load existing checkpoint through `get_checkpoint()` API.
|
| 97 |
+
if train.get_checkpoint():
|
| 98 |
+
loaded_checkpoint = train.get_checkpoint()
|
| 99 |
+
with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
|
| 100 |
+
model_state, optimizer_state = torch.load(
|
| 101 |
+
os.path.join(loaded_checkpoint_dir, "checkpoint.pt")
|
| 102 |
+
)
|
| 103 |
+
net.load_state_dict(model_state)
|
| 104 |
+
optimizer.load_state_dict(optimizer_state)
|
| 105 |
+
|
| 106 |
+
if config["smoke_test"]:
|
| 107 |
+
trainset, testset = load_test_data()
|
| 108 |
+
else:
|
| 109 |
+
trainset, testset = load_data(DATA_DIR)
|
| 110 |
+
|
| 111 |
+
test_abs = int(len(trainset) * 0.8)
|
| 112 |
+
train_subset, val_subset = random_split(
|
| 113 |
+
trainset, [test_abs, len(trainset) - test_abs])
|
| 114 |
+
|
| 115 |
+
trainloader = torch.utils.data.DataLoader(
|
| 116 |
+
train_subset,
|
| 117 |
+
batch_size=int(config["batch_size"]),
|
| 118 |
+
shuffle=True,
|
| 119 |
+
num_workers=0 if config["smoke_test"] else 8,
|
| 120 |
+
)
|
| 121 |
+
valloader = torch.utils.data.DataLoader(
|
| 122 |
+
val_subset,
|
| 123 |
+
batch_size=int(config["batch_size"]),
|
| 124 |
+
shuffle=True,
|
| 125 |
+
num_workers=0 if config["smoke_test"] else 8,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
for epoch in range(10): # loop over the dataset multiple times
|
| 129 |
+
running_loss = 0.0
|
| 130 |
+
epoch_steps = 0
|
| 131 |
+
for i, data in enumerate(trainloader):
|
| 132 |
+
# get the inputs; data is a list of [inputs, labels]
|
| 133 |
+
inputs, labels = data
|
| 134 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
| 135 |
+
|
| 136 |
+
# zero the parameter gradients
|
| 137 |
+
optimizer.zero_grad()
|
| 138 |
+
|
| 139 |
+
# forward + backward + optimize
|
| 140 |
+
outputs = net(inputs)
|
| 141 |
+
loss = criterion(outputs, labels)
|
| 142 |
+
loss.backward()
|
| 143 |
+
optimizer.step()
|
| 144 |
+
|
| 145 |
+
# print statistics
|
| 146 |
+
running_loss += loss.item()
|
| 147 |
+
epoch_steps += 1
|
| 148 |
+
if i % 2000 == 1999: # print every 2000 mini-batches
|
| 149 |
+
print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
|
| 150 |
+
running_loss / epoch_steps))
|
| 151 |
+
running_loss = 0.0
|
| 152 |
+
|
| 153 |
+
# Validation loss
|
| 154 |
+
val_loss = 0.0
|
| 155 |
+
val_steps = 0
|
| 156 |
+
total = 0
|
| 157 |
+
correct = 0
|
| 158 |
+
for i, data in enumerate(valloader, 0):
|
| 159 |
+
with torch.no_grad():
|
| 160 |
+
inputs, labels = data
|
| 161 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
| 162 |
+
|
| 163 |
+
outputs = net(inputs)
|
| 164 |
+
_, predicted = torch.max(outputs.data, 1)
|
| 165 |
+
total += labels.size(0)
|
| 166 |
+
correct += (predicted == labels).sum().item()
|
| 167 |
+
|
| 168 |
+
loss = criterion(outputs, labels)
|
| 169 |
+
val_loss += loss.cpu().numpy()
|
| 170 |
+
val_steps += 1
|
| 171 |
+
|
| 172 |
+
# Here we save a checkpoint. It is automatically registered with
|
| 173 |
+
# Ray Tune and will potentially be accessed through in ``get_checkpoint()``
|
| 174 |
+
# in future iterations.
|
| 175 |
+
# Note to save a file like checkpoint, you still need to put it under a directory
|
| 176 |
+
# to construct a checkpoint.
|
| 177 |
+
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
|
| 178 |
+
path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
|
| 179 |
+
torch.save(
|
| 180 |
+
(net.state_dict(), optimizer.state_dict()), path
|
| 181 |
+
)
|
| 182 |
+
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
|
| 183 |
+
train.report(
|
| 184 |
+
{"loss": (val_loss / val_steps), "accuracy": correct / total},
|
| 185 |
+
checkpoint=checkpoint,
|
| 186 |
+
)
|
| 187 |
+
print("Finished Training")
|
| 188 |
+
# __train_end__
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# __test_acc_begin__
|
| 192 |
+
def test_best_model(config: Dict, checkpoint: "Checkpoint", smoke_test=False):
|
| 193 |
+
best_trained_model = Net(config["l1"], config["l2"])
|
| 194 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 195 |
+
best_trained_model.to(device)
|
| 196 |
+
|
| 197 |
+
with checkpoint.as_directory() as checkpoint_dir:
|
| 198 |
+
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pt")
|
| 199 |
+
model_state, optimizer_state = torch.load(checkpoint_path)
|
| 200 |
+
best_trained_model.load_state_dict(model_state)
|
| 201 |
+
|
| 202 |
+
if smoke_test:
|
| 203 |
+
_, testset = load_test_data()
|
| 204 |
+
else:
|
| 205 |
+
_, testset = load_data(DATA_DIR)
|
| 206 |
+
|
| 207 |
+
testloader = torch.utils.data.DataLoader(
|
| 208 |
+
testset, batch_size=4, shuffle=False, num_workers=2)
|
| 209 |
+
|
| 210 |
+
correct = 0
|
| 211 |
+
total = 0
|
| 212 |
+
with torch.no_grad():
|
| 213 |
+
for data in testloader:
|
| 214 |
+
images, labels = data
|
| 215 |
+
images, labels = images.to(device), labels.to(device)
|
| 216 |
+
outputs = best_trained_model(images)
|
| 217 |
+
_, predicted = torch.max(outputs.data, 1)
|
| 218 |
+
total += labels.size(0)
|
| 219 |
+
correct += (predicted == labels).sum().item()
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
print("Best trial test set accuracy: {}".format(correct / total))
|
| 223 |
+
|
| 224 |
+
# __test_acc_end__
|
| 225 |
+
|
| 226 |
+
# __main_begin__
|
| 227 |
+
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2, smoke_test=False):
|
| 228 |
+
config = {
|
| 229 |
+
"l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
|
| 230 |
+
"l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
|
| 231 |
+
"lr": tune.loguniform(1e-4, 1e-1),
|
| 232 |
+
"batch_size": tune.choice([2, 4, 8, 16]),
|
| 233 |
+
"smoke_test": smoke_test,
|
| 234 |
+
}
|
| 235 |
+
scheduler = ASHAScheduler(
|
| 236 |
+
max_t=max_num_epochs,
|
| 237 |
+
grace_period=1,
|
| 238 |
+
reduction_factor=2)
|
| 239 |
+
|
| 240 |
+
tuner = tune.Tuner(
|
| 241 |
+
tune.with_resources(
|
| 242 |
+
tune.with_parameters(train_cifar),
|
| 243 |
+
resources={"cpu": 2, "gpu": gpus_per_trial},
|
| 244 |
+
),
|
| 245 |
+
tune_config=tune.TuneConfig(
|
| 246 |
+
metric="loss",
|
| 247 |
+
mode="min",
|
| 248 |
+
num_samples=num_samples,
|
| 249 |
+
scheduler=scheduler
|
| 250 |
+
),
|
| 251 |
+
param_space=config,
|
| 252 |
+
)
|
| 253 |
+
results = tuner.fit()
|
| 254 |
+
best_result = results.get_best_result("loss", "min")
|
| 255 |
+
print("Best trial config: {}".format(best_result.config))
|
| 256 |
+
print("Best trial final validation loss: {}".format(
|
| 257 |
+
best_result.metrics["loss"]))
|
| 258 |
+
print("Best trial final validation accuracy: {}".format(
|
| 259 |
+
best_result.metrics["accuracy"]))
|
| 260 |
+
|
| 261 |
+
test_best_model(best_result.config, best_result.checkpoint, smoke_test=smoke_test)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# __main_end__
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
if __name__ == "__main__":
|
| 268 |
+
import argparse
|
| 269 |
+
|
| 270 |
+
parser = argparse.ArgumentParser()
|
| 271 |
+
parser.add_argument(
|
| 272 |
+
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
| 273 |
+
parser.add_argument(
|
| 274 |
+
"--ray-address",
|
| 275 |
+
help="Address of Ray cluster for seamless distributed execution.",
|
| 276 |
+
required=False)
|
| 277 |
+
args, _ = parser.parse_known_args()
|
| 278 |
+
|
| 279 |
+
if args.smoke_test:
|
| 280 |
+
ray.init(num_cpus=2)
|
| 281 |
+
main(num_samples=1, max_num_epochs=1, gpus_per_trial=0, smoke_test=True)
|
| 282 |
+
else:
|
| 283 |
+
ray.init(args.ray_address)
|
| 284 |
+
# Change this to activate training on GPUs
|
| 285 |
+
main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)
|
.venv/lib/python3.11/site-packages/ray/tune/examples/lightgbm_example.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import lightgbm as lgb
|
| 2 |
+
import sklearn.datasets
|
| 3 |
+
import sklearn.metrics
|
| 4 |
+
from sklearn.model_selection import train_test_split
|
| 5 |
+
|
| 6 |
+
from ray import tune
|
| 7 |
+
from ray.tune.integration.lightgbm import TuneReportCheckpointCallback
|
| 8 |
+
from ray.tune.schedulers import ASHAScheduler
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def train_breast_cancer(config: dict):
|
| 12 |
+
# This is a simple training function to be passed into Tune
|
| 13 |
+
|
| 14 |
+
# Load dataset
|
| 15 |
+
data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
|
| 16 |
+
|
| 17 |
+
# Split into train and test set
|
| 18 |
+
train_x, test_x, train_y, test_y = train_test_split(data, target, test_size=0.25)
|
| 19 |
+
|
| 20 |
+
# Build input Datasets for LightGBM
|
| 21 |
+
train_set = lgb.Dataset(train_x, label=train_y)
|
| 22 |
+
test_set = lgb.Dataset(test_x, label=test_y)
|
| 23 |
+
|
| 24 |
+
# Train the classifier, using the Tune callback
|
| 25 |
+
lgb.train(
|
| 26 |
+
config,
|
| 27 |
+
train_set,
|
| 28 |
+
valid_sets=[test_set],
|
| 29 |
+
valid_names=["eval"],
|
| 30 |
+
verbose_eval=False,
|
| 31 |
+
callbacks=[
|
| 32 |
+
TuneReportCheckpointCallback(
|
| 33 |
+
{
|
| 34 |
+
"binary_error": "eval-binary_error",
|
| 35 |
+
"binary_logloss": "eval-binary_logloss",
|
| 36 |
+
}
|
| 37 |
+
)
|
| 38 |
+
],
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def train_breast_cancer_cv(config: dict):
|
| 43 |
+
# This is a simple training function to be passed into Tune, using
|
| 44 |
+
# lightgbm's cross validation functionality
|
| 45 |
+
|
| 46 |
+
# Load dataset
|
| 47 |
+
data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
|
| 48 |
+
|
| 49 |
+
train_set = lgb.Dataset(data, label=target)
|
| 50 |
+
|
| 51 |
+
# Run CV, using the Tune callback
|
| 52 |
+
lgb.cv(
|
| 53 |
+
config,
|
| 54 |
+
train_set,
|
| 55 |
+
verbose_eval=False,
|
| 56 |
+
stratified=True,
|
| 57 |
+
# Checkpointing is not supported for CV
|
| 58 |
+
# LightGBM aggregates metrics over folds automatically
|
| 59 |
+
# with the cv_agg key. Both mean and standard deviation
|
| 60 |
+
# are provided.
|
| 61 |
+
callbacks=[
|
| 62 |
+
TuneReportCheckpointCallback(
|
| 63 |
+
{
|
| 64 |
+
"binary_error": "cv_agg-binary_error-mean",
|
| 65 |
+
"binary_logloss": "cv_agg-binary_logloss-mean",
|
| 66 |
+
"binary_error_stdv": "cv_agg-binary_error-stdv",
|
| 67 |
+
"binary_logloss_stdv": "cv_agg-binary_logloss-stdv",
|
| 68 |
+
},
|
| 69 |
+
frequency=0,
|
| 70 |
+
)
|
| 71 |
+
],
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
|
| 76 |
+
import argparse
|
| 77 |
+
|
| 78 |
+
parser = argparse.ArgumentParser()
|
| 79 |
+
parser.add_argument(
|
| 80 |
+
"--use-cv", action="store_true", help="Use `lgb.cv` instead of `lgb.train`."
|
| 81 |
+
)
|
| 82 |
+
args, _ = parser.parse_known_args()
|
| 83 |
+
|
| 84 |
+
config = {
|
| 85 |
+
"objective": "binary",
|
| 86 |
+
"metric": ["binary_error", "binary_logloss"],
|
| 87 |
+
"verbose": -1,
|
| 88 |
+
"boosting_type": tune.grid_search(["gbdt", "dart"]),
|
| 89 |
+
"num_leaves": tune.randint(10, 1000),
|
| 90 |
+
"learning_rate": tune.loguniform(1e-8, 1e-1),
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
tuner = tune.Tuner(
|
| 94 |
+
train_breast_cancer if not args.use_cv else train_breast_cancer_cv,
|
| 95 |
+
tune_config=tune.TuneConfig(
|
| 96 |
+
metric="binary_error",
|
| 97 |
+
mode="min",
|
| 98 |
+
num_samples=2,
|
| 99 |
+
scheduler=ASHAScheduler(),
|
| 100 |
+
),
|
| 101 |
+
param_space=config,
|
| 102 |
+
)
|
| 103 |
+
results = tuner.fit()
|
| 104 |
+
|
| 105 |
+
print("Best hyperparameters found were: ", results.get_best_result().config)
|
.venv/lib/python3.11/site-packages/ray/tune/execution/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (191 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/class_cache.cpython-311.pyc
ADDED
|
Binary file (2.78 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/cluster_info.cpython-311.pyc
ADDED
|
Binary file (824 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/experiment_state.cpython-311.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/insufficient_resources_manager.cpython-311.pyc
ADDED
|
Binary file (7.59 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/execution/__pycache__/placement_groups.cpython-311.pyc
ADDED
|
Binary file (5.54 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/execution/class_cache.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import ray
|
| 4 |
+
from ray.air.constants import COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV
|
| 5 |
+
from ray.train.constants import (
|
| 6 |
+
ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR,
|
| 7 |
+
RAY_CHDIR_TO_TRIAL_DIR,
|
| 8 |
+
)
|
| 9 |
+
from ray.train.v2._internal.constants import (
|
| 10 |
+
ENV_VARS_TO_PROPAGATE as TRAIN_ENV_VARS_TO_PROPAGATE,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
DEFAULT_ENV_VARS = {
|
| 14 |
+
# https://github.com/ray-project/ray/issues/28197
|
| 15 |
+
"PL_DISABLE_FORK": "1"
|
| 16 |
+
}
|
| 17 |
+
ENV_VARS_TO_PROPAGATE = (
|
| 18 |
+
{
|
| 19 |
+
COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
|
| 20 |
+
RAY_CHDIR_TO_TRIAL_DIR,
|
| 21 |
+
ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR,
|
| 22 |
+
"AWS_ACCESS_KEY_ID",
|
| 23 |
+
"AWS_SECRET_ACCESS_KEY",
|
| 24 |
+
"AWS_SECURITY_TOKEN",
|
| 25 |
+
"AWS_SESSION_TOKEN",
|
| 26 |
+
}
|
| 27 |
+
# Propagate the Ray Train environment variables from the driver process
|
| 28 |
+
# to the trainable process so that Tune + Train v2 can be used together.
|
| 29 |
+
| TRAIN_ENV_VARS_TO_PROPAGATE
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class _ActorClassCache:
    """Caches the result of wrapping trainable classes with ``ray.remote``.

    ``ray.remote`` is a registration call: it serializes the class and
    uploads it to the cluster's key-value store, from where an arbitrary
    worker fetches it later. Registration consumes no scheduling resources,
    but repeating it for the same class would re-upload the same serialized
    object over and over and bloat the store. This cache guarantees each
    trainable class is registered exactly once per driver process.

    Multiple distinct trainables may coexist in the system at once; each
    gets its own cache entry.
    """

    def __init__(self):
        # Maps trainable class -> ray.remote-wrapped actor class.
        self._cache = {}

    def get(self, trainable_cls):
        """Return the wrapped trainable_cls, calling ``ray.remote`` on first use."""
        # Recompute the runtime env on every call so the propagated values
        # reflect the driver's current environment.
        env_vars = dict(DEFAULT_ENV_VARS)
        env_vars.update(
            (name, os.environ[name])
            for name in ENV_VARS_TO_PROPAGATE
            if name in os.environ
        )

        runtime_env = {"env_vars": env_vars}
        remote_cls = self._cache.get(trainable_cls)
        if remote_cls is None:
            remote_cls = ray.remote(runtime_env=runtime_env)(trainable_cls)
            self._cache[trainable_cls] = remote_cls
        return remote_cls
|
.venv/lib/python3.11/site-packages/ray/tune/execution/cluster_info.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import lru_cache
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@lru_cache()
|
| 6 |
+
def _is_ray_cluster():
|
| 7 |
+
"""Checks if the bootstrap config file exists.
|
| 8 |
+
|
| 9 |
+
This will always exist if using an autoscaling cluster/started
|
| 10 |
+
with the ray cluster launcher.
|
| 11 |
+
"""
|
| 12 |
+
return Path("~/ray_bootstrap_config.yaml").expanduser().exists()
|
.venv/lib/python3.11/site-packages/ray/tune/execution/experiment_state.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fnmatch
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
from collections import Counter
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Callable, Dict, Optional, Union
|
| 8 |
+
|
| 9 |
+
import pyarrow.fs
|
| 10 |
+
|
| 11 |
+
from ray.train._internal.storage import (
|
| 12 |
+
StorageContext,
|
| 13 |
+
_download_from_fs_path,
|
| 14 |
+
_list_at_fs_path,
|
| 15 |
+
get_fs_and_path,
|
| 16 |
+
)
|
| 17 |
+
from ray.tune.experiment.trial import Trial
|
| 18 |
+
from ray.tune.impl.out_of_band_serialize_dataset import out_of_band_serialize_dataset
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Suffix appended to slow-sync warning messages; ``{threshold}`` is filled
# in via ``str.format`` with the active slow-sync threshold (seconds).
_SLOW_SYNC_WARNING = (
    "This could be due to a large number of trials, "
    "large logfiles from lots of reported metrics, or throttling from the "
    "remote storage if uploading too frequently.\n"
    "You may want to consider switching the `RunConfig(storage_filesystem)`"
    " to a more performant storage backend such as s3fs for a "
    "S3 storage path.\n"
    "You can suppress this error by setting the environment variable "
    "TUNE_WARN_SLOW_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a higher "
    "value than the current threshold ({threshold})."
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _find_newest_experiment_checkpoint(
    experiment_path: str, fs: Optional[pyarrow.fs.FileSystem] = None
) -> Optional[str]:
    """Returns file name of most recently created experiment checkpoint.

    Args:
        experiment_path: Local or remote path to the experiment directory
            containing at least one experiment checkpoint file.
        fs: Optional filesystem to use; resolved from the path otherwise.

    Returns:
        str: The local or remote path to the latest experiment checkpoint file
            based on timestamp. None if no experiment checkpoints were found.
    """
    # Imported locally to avoid a circular import with the controller module.
    from ray.tune.execution.tune_controller import TuneController

    fs, experiment_fs_path = get_fs_and_path(experiment_path, storage_filesystem=fs)
    # Checkpoint filenames embed a sortable timestamp, so the lexicographic
    # maximum is the most recent one.
    candidates = fnmatch.filter(
        _list_at_fs_path(fs=fs, fs_path=experiment_fs_path),
        TuneController.CKPT_FILE_TMPL.format("*"),
    )
    if not candidates:
        return None
    return Path(experiment_fs_path, max(candidates)).as_posix()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class _ExperimentCheckpointManager:
    """Helper class for managing experiment-level checkpoints.

    This class implements the ``checkpoint()`` method used to checkpoint
    experiment state. When called, this will serialize and write to disk
    the state of the trial runner, trial executor, and search algorithm, to
    a specified checkpoint file.

    The checkpoint period is automatically adjusted to
    ``max(10, time_per_checkpoint * 19)``. This means that at most 5% of the
    time (1/20) will be used for writing checkpoints, while 95% of the time
    (19/20) will be used to handle the rest of the training loop.
    """

    def __init__(
        self,
        *,
        storage: Optional[StorageContext],
        checkpoint_period: Union[int, float, str],
        sync_every_n_trial_checkpoints: Optional[int] = None,
    ):
        # Storage context providing the staging path, fs path, and syncer.
        self._storage = storage

        # -inf so the very first checkpoint attempt is never throttled.
        self._last_save_time = float("-inf")
        # None until the first sync has been launched.
        self._last_sync_time = None

        # Dynamic checkpointing period
        self._auto_checkpoint_enabled = checkpoint_period == "auto"
        if self._auto_checkpoint_enabled:
            self._checkpoint_period = 10.0  # Initial value
        else:
            self._checkpoint_period = float(checkpoint_period)

        # TODO(justinvyu): This is a non-performant workaround to force sync
        # every num_to_keep checkpoints in order to maintain consistency
        # between the experiment state's view of the latest checkpoint,
        # and the actual latest checkpoint that was uploaded.
        self._sync_every_n_trial_checkpoints = sync_every_n_trial_checkpoints
        self._trial_num_checkpoints_since_last_sync: Dict[Trial, int] = Counter()
        self._should_force_sync_up: bool = False

        # Thresholds (seconds) controlling the two warning messages below;
        # both are overridable via environment variables.
        self._excessive_sync_threshold = float(
            os.environ.get(
                "TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S", "5"
            )
        )
        self._slow_sync_threshold = float(
            os.environ.get(
                "TUNE_WARN_SLOW_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S", "30"
            )
        )

    @property
    def auto_checkpoint_enabled(self):
        # True iff ``checkpoint_period="auto"`` was requested.
        return self._auto_checkpoint_enabled

    def _update_auto_checkpoint_time(self, time_taken: float):
        """Adapt the checkpoint period based on how long the last snapshot took."""
        if self._auto_checkpoint_enabled:
            # Multiplying this time by 19 means we spend ~5% of the time
            # writing global checkpoints and 95% of the time processing trials
            self._checkpoint_period = max(10.0, time_taken * 19)
            logger.debug(
                f"Experiment state snapshotting took "
                f"{time_taken:.2f} seconds. "
                f"Adjusting snapshotting period to "
                f"{self._checkpoint_period:.2f} seconds."
            )

    def sync_up_experiment_state(
        self,
        save_fn: Callable[[], None],
        force: bool = False,
        wait: bool = False,
    ):
        """Saves execution state to the experiment directory on the storage path.
        This includes an experiment checkpoint file that contains trial statuses
        and the searcher state.

        Overwrites the current session checkpoint, which starts when self
        is instantiated. Throttle depends on self._checkpoint_period.

        Args:
            save_fn: Function to call to actually save data to the driver
                staging path. The files in the driver staging path will be
                uploaded to the storage path.
            force: Forces an experiment checkpoint and launches a sync to storage.
                This happens regardless of checkpoint_period
            wait: Waits for the sync up to complete before returning.
        """
        driver_staging_path = self._storage.experiment_driver_staging_path

        # A pending "force" request (set by on_trial_checkpoint) also forces.
        force = force or self._should_force_sync_up

        now = time.monotonic()
        if now - self._last_save_time < self._checkpoint_period and not force:
            return

        # Checkpoint
        checkpoint_time_start = time.monotonic()

        # NOTE: This context manager is for Datasets captured in a trial config.
        # This is the case when *tuning over datasets*.
        # If the datasets have already been full executed, then serializing
        # block refs means that this checkpoint is not usable in a new Ray cluster.
        # This context will serialize the dataset execution plan instead, if available.
        with out_of_band_serialize_dataset():
            save_fn()

        def wait_for_sync():
            # Block on the syncer; failures are logged, not raised, so an
            # upload error never crashes the training loop.
            try:
                self._storage.syncer.wait()
            except Exception:
                logger.error(
                    "Saving experiment state to storage at "
                    f"'{self._storage.experiment_fs_path}' failed with exception: ",
                    exc_info=True,
                )

        # When forcing, drain any in-flight sync first so the new sync below
        # is guaranteed to launch (and time how long the drain takes).
        if force:
            start_time = time.monotonic()
            wait_for_sync()
            wait_time = time.monotonic() - start_time
            if wait_time > self._slow_sync_threshold:
                logger.warning(
                    "Saving the experiment state (which holds a global view "
                    "of trial statuses and is used to restore the experiment) "
                    f"took ~{wait_time:.2f} seconds, which may be a performance "
                    "bottleneck.\n"
                    f"{_SLOW_SYNC_WARNING.format(threshold=self._slow_sync_threshold)}"
                )

        time_since_last_sync = (
            time.monotonic() - self._last_sync_time
            if self._last_sync_time is not None
            else None
        )
        # May return False if a previous sync is still in flight.
        launched_sync = self._storage.syncer.sync_up(
            driver_staging_path, self._storage.experiment_fs_path
        )
        if launched_sync:
            if (
                time_since_last_sync is not None
                and time_since_last_sync < self._excessive_sync_threshold
                and self._should_force_sync_up
            ):
                logger.warning(
                    "Experiment state snapshotting has been triggered multiple "
                    f"times in the last {self._excessive_sync_threshold} seconds "
                    "and may become a bottleneck. "
                    "A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, "
                    "and a trial has checkpointed >= `num_to_keep` times "
                    "since the last snapshot.\n"
                    "You may want to consider increasing the "
                    "`CheckpointConfig(num_to_keep)` or decreasing the frequency of "
                    "saving checkpoints.\n"
                    "You can suppress this warning by setting the environment variable "
                    "TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S "
                    "to a smaller value than the current threshold "
                    f"({self._excessive_sync_threshold}). "
                    "Set it to 0 to completely suppress this warning."
                )

            self._last_sync_time = time.monotonic()

            # We just synced, so reset the force flag
            self._trial_num_checkpoints_since_last_sync.clear()
            self._should_force_sync_up = False
        else:
            if (
                time_since_last_sync is not None
                and time_since_last_sync > self._slow_sync_threshold
            ):
                logger.warning(
                    "Saving the experiment state (which holds a global view "
                    "of trial statuses and is used to restore the experiment) "
                    f"has already taken {time_since_last_sync:.2f} seconds, "
                    "which may cause consistency issues upon restoration if your "
                    "driver script ungracefully exits.\n"
                    f"{_SLOW_SYNC_WARNING.format(threshold=self._slow_sync_threshold)}"
                )

        if wait:
            wait_for_sync()

        checkpoint_time_taken = time.monotonic() - checkpoint_time_start

        # Adjust dynamic checkpointing
        self._update_auto_checkpoint_time(time_taken=checkpoint_time_taken)

        # Finish
        self._last_save_time = time.monotonic()

    def sync_down_experiment_state(self) -> None:
        """Download experiment state files from storage to the driver staging dir."""
        fs = self._storage.storage_filesystem
        filepaths = _list_at_fs_path(fs=fs, fs_path=self._storage.experiment_fs_path)
        # TODO(ekl) we should refactor our restore code to read the necessary data
        # directly from the storage context. As a temporary hack, restore all the
        # serialized files from the root dir where other modules expect them to be.
        matches = [
            path
            for path in filepaths
            if path.endswith(".json") or path.endswith(".pkl")
        ]
        for relpath in matches:
            fs_path = Path(self._storage.experiment_fs_path, relpath).as_posix()
            local_path = Path(
                self._storage.experiment_driver_staging_path, relpath
            ).as_posix()
            _download_from_fs_path(fs=fs, fs_path=fs_path, local_path=local_path)
        logger.debug(
            f"Copied {matches} from:\n(fs, path) = "
            f"({self._storage.storage_filesystem.type_name}, "
            f"{self._storage.experiment_fs_path})\n"
            f"-> {self._storage.experiment_driver_staging_path}"
        )

    def on_trial_checkpoint(self, trial: Trial):
        """Count a trial checkpoint; force a sync after every N of them.

        No-op unless ``sync_every_n_trial_checkpoints`` was configured.
        """
        if not self._sync_every_n_trial_checkpoints:
            return

        self._trial_num_checkpoints_since_last_sync[trial] += 1

        if (
            self._trial_num_checkpoints_since_last_sync[trial]
            >= self._sync_every_n_trial_checkpoints
        ):
            self._should_force_sync_up = True
|
.venv/lib/python3.11/site-packages/ray/tune/execution/insufficient_resources_manager.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from typing import Dict, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import ray
|
| 8 |
+
from ray.tune.execution.cluster_info import _is_ray_cluster
|
| 9 |
+
from ray.tune.experiment import Trial
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Ideally we want to use @cache; but it's only available for python 3.9.
|
| 15 |
+
# Caching is only helpful/correct for no autoscaler case.
|
| 16 |
+
@lru_cache()
def _get_cluster_resources_no_autoscaler() -> Dict:
    """Return total cluster resources, cached for the process lifetime.

    Caching is only correct when no autoscaler is running (the total cannot
    change), which is the only case in which this helper is used.
    """
    return ray.cluster_resources()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _get_trial_cpu_and_gpu(trial: Trial) -> Tuple[int, int]:
|
| 22 |
+
cpu = trial.placement_group_factory.required_resources.get("CPU", 0)
|
| 23 |
+
gpu = trial.placement_group_factory.required_resources.get("GPU", 0)
|
| 24 |
+
return cpu, gpu
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _can_fulfill_no_autoscaler(trial: "Trial") -> bool:
    """Calculates if there is enough resources for a PENDING trial.

    For no autoscaler case.
    """
    assert trial.status == Trial.PENDING
    asked_cpus, asked_gpus = _get_trial_cpu_and_gpu(trial)

    available = _get_cluster_resources_no_autoscaler()
    return (
        asked_cpus <= available.get("CPU", 0)
        and asked_gpus <= available.get("GPU", 0)
    )
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@lru_cache()
def _get_insufficient_resources_warning_threshold() -> float:
    """Seconds to wait with no progress before warning about resources.

    Overridable via environment variables; the result is cached for the
    process lifetime.
    """
    if _is_ray_cluster():
        return float(
            os.environ.get(
                "TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S_AUTOSCALER", "60"
            )
        )
    else:
        # Use a generous default so that we don't prematurely determine that
        # a cluster cannot fulfill the resources requirements.
        # NOTE(review): an earlier comment here said "10s", but the actual
        # default is 60s.
        # TODO(xwjiang): Change it back once #18608 is resolved.
        return float(os.environ.get("TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S", "60"))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Warning message templates. The *_START parts take a {wait_time} placeholder;
# the *_INSUFFICIENT parts take {asked_*}/{cluster_*} placeholders.
MSG_TRAIN_START = (
    "Training has not started in the last {wait_time:.0f} seconds. "
    "This could be due to the cluster not having enough resources available. "
)
MSG_TRAIN_INSUFFICIENT = (
    "You asked for {asked_cpus} CPUs and {asked_gpus} GPUs, but the cluster only "
    "has {cluster_cpus} CPUs and {cluster_gpus} GPUs available. "
)
MSG_TRAIN_END = (
    "Stop the training and adjust the required resources (e.g. via the "
    "`ScalingConfig` or `resources_per_trial`, or `num_workers` for rllib), "
    "or add more resources to your cluster."
)

MSG_TUNE_START = (
    "No trial is running and no new trial has been started within "
    "the last {wait_time:.0f} seconds. "
    "This could be due to the cluster not having enough resources available. "
)
MSG_TUNE_INSUFFICIENT = (
    "You asked for {asked_cpus} CPUs and {asked_gpus} GPUs per trial, "
    "but the cluster only has {cluster_cpus} CPUs and {cluster_gpus} GPUs available. "
)
MSG_TUNE_END = (
    "Stop the tuning and adjust the required resources (e.g. via the "
    "`ScalingConfig` or `resources_per_trial`, or `num_workers` for rllib), "
    "or add more resources to your cluster."
)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# TODO(xwjiang): Consider having a help page with more detailed instructions.
|
| 86 |
+
@lru_cache()
def _get_insufficient_resources_warning_msg(
    for_train: bool = False, trial: Optional[Trial] = None
) -> str:
    """Assemble the full insufficient-resources warning message.

    NOTE(review): ``lru_cache`` keys on the ``trial`` argument, so cached
    entries keep those Trial objects alive for the process lifetime —
    presumably acceptable since only a representative trial is passed, but
    worth confirming.
    """
    msg = "Ignore this message if the cluster is autoscaling. "

    # Pick the template set depending on whether this is a Train or Tune run.
    if for_train:
        start = MSG_TRAIN_START
        insufficient = MSG_TRAIN_INSUFFICIENT
        end = MSG_TRAIN_END
    else:
        start = MSG_TUNE_START
        insufficient = MSG_TUNE_INSUFFICIENT
        end = MSG_TUNE_END

    msg += start.format(wait_time=_get_insufficient_resources_warning_threshold())

    # Only include concrete numbers when a representative trial is given.
    if trial:
        asked_cpus, asked_gpus = _get_trial_cpu_and_gpu(trial)
        cluster_resources = _get_cluster_resources_no_autoscaler()

        msg += insufficient.format(
            asked_cpus=asked_cpus,
            asked_gpus=asked_gpus,
            cluster_cpus=cluster_resources.get("CPU", 0),
            cluster_gpus=cluster_resources.get("GPU", 0),
        )

    msg += end

    return msg
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class _InsufficientResourcesManager:
    """Insufficient resources manager.

    Makes best effort, conservative guesses about if Tune loop is stuck due to
    infeasible resources. If so, outputs usability messages for users to
    act upon.
    """

    def __init__(self, for_train: bool = False):
        # The information tracked across the life time of Tune loop.
        # -1 is a sentinel meaning "progress was made recently".
        self._no_running_trials_since = -1
        self._last_trial_num = -1
        # Selects Train- vs Tune-flavored warning messages.
        self._for_train = for_train

    def on_no_available_trials(self, all_trials):
        """Tracks information across the life of Tune loop and makes guesses
        about if Tune loop is stuck due to infeasible resources.
        If so, outputs certain warning messages.
        The logic should be conservative, non-intrusive and informative.
        For example, rate limiting is applied so that the message is not
        spammy.
        """
        # This is approximately saying we are not making progress.
        if len(all_trials) == self._last_trial_num:
            if self._no_running_trials_since == -1:
                # Start the no-progress timer.
                self._no_running_trials_since = time.monotonic()
            elif (
                time.monotonic() - self._no_running_trials_since
                > _get_insufficient_resources_warning_threshold()
            ):
                can_fulfill_any = any(
                    trial.status == Trial.PENDING and _can_fulfill_no_autoscaler(trial)
                    for trial in all_trials
                )

                if can_fulfill_any:
                    # If one trial can be fulfilled, it will be fulfilled eventually
                    self._no_running_trials_since = -1
                    return

                # Otherwise, can fulfill none
                msg = _get_insufficient_resources_warning_msg(
                    for_train=self._for_train, trial=all_trials[0]
                )
                logger.warning(msg)
                # Restart the timer so the warning is rate-limited rather
                # than emitted on every loop iteration.
                self._no_running_trials_since = time.monotonic()
        else:
            self._no_running_trials_since = -1
        self._last_trial_num = len(all_trials)
|
.venv/lib/python3.11/site-packages/ray/tune/execution/placement_groups.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Dict, Optional
|
| 3 |
+
|
| 4 |
+
from ray.air.execution.resources.request import ResourceRequest
|
| 5 |
+
from ray.util.annotations import DeveloperAPI, PublicAPI
|
| 6 |
+
from ray.util.placement_group import placement_group
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@PublicAPI(stability="beta")
class PlacementGroupFactory(ResourceRequest):
    """Wrapper class that creates placement groups for trials.

    This function should be used to define resource requests for Ray Tune
    trials. It holds the parameters to create
    :ref:`placement groups <ray-placement-group-doc-ref>`.
    At a minimum, this will hold at least one bundle specifying the
    resource requirements for each trial:

    .. code-block:: python

        from ray import tune

        tuner = tune.Tuner(
            tune.with_resources(
                train,
                resources=tune.PlacementGroupFactory([
                    {"CPU": 1, "GPU": 0.5, "custom_resource": 2}
                ])
            )
        )
        tuner.fit()

    If the trial itself schedules further remote workers, the resource
    requirements should be specified in additional bundles. You can also
    pass the placement strategy for these bundles, e.g. to enforce
    co-located placement:

    .. code-block:: python

        from ray import tune

        tuner = tune.Tuner(
            tune.with_resources(
                train,
                resources=tune.PlacementGroupFactory([
                    {"CPU": 1, "GPU": 0.5, "custom_resource": 2},
                    {"CPU": 2},
                    {"CPU": 2},
                ], strategy="PACK")
            )
        )
        tuner.fit()

    The example above will reserve 1 CPU, 0.5 GPUs and 2 custom_resources
    for the trainable itself, and reserve another 2 bundles of 2 CPUs each.
    The trial will only start when all these resources are available. This
    could be used e.g. if you had one learner running in the main trainable
    that schedules two remote workers that need access to 2 CPUs each.

    If the trainable itself doesn't require resources.
    You can specify it as:

    .. code-block:: python

        from ray import tune

        tuner = tune.Tuner(
            tune.with_resources(
                train,
                resources=tune.PlacementGroupFactory([
                    {},
                    {"CPU": 2},
                    {"CPU": 2},
                ], strategy="PACK")
            )
        )
        tuner.fit()

    Args:
        bundles: A list of bundles which
            represent the resources requirements.
        strategy: The strategy to create the placement group.

         - "PACK": Packs Bundles into as few nodes as possible.
         - "SPREAD": Places Bundles across distinct nodes as even as possible.
         - "STRICT_PACK": Packs Bundles into one node. The group is
           not allowed to span multiple nodes.
         - "STRICT_SPREAD": Packs Bundles across distinct nodes.
        *args: Passed to the call of ``placement_group()``
        **kwargs: Passed to the call of ``placement_group()``

    """

    def __call__(self, *args, **kwargs):
        # Deprecated entry point; kept for backward compatibility with code
        # that called the factory directly.
        warnings.warn(
            "Calling PlacementGroupFactory objects is deprecated. Use "
            "`to_placement_group()` instead.",
            DeprecationWarning,
        )
        # Kwargs bound at construction time override call-time kwargs.
        kwargs.update(self._bound.kwargs)
        # Call with bounded *args and **kwargs
        return placement_group(*self._bound.args, **kwargs)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@DeveloperAPI
def resource_dict_to_pg_factory(spec: Optional[Dict[str, float]] = None):
    """Translates resource dict into PlacementGroupFactory.

    Accepts both lower- and upper-case ``cpu``/``gpu`` keys; any remaining
    keys (or an explicit ``custom_resources`` dict) become custom resources
    in a single-bundle placement group.
    """
    working = dict(spec) if spec else {"cpu": 1}

    # Both casings are removed from the dict; the lowercase value wins
    # if both are present.
    cpus = working.pop("cpu", working.pop("CPU", 0.0))
    gpus = working.pop("gpu", working.pop("GPU", 0.0))
    memory = working.pop("memory", 0.0)

    # An explicit custom_resources mapping takes precedence as the bundle base.
    bundle = dict(working.pop("custom_resources", {}))

    # Otherwise, every leftover key is treated as a custom resource.
    if not bundle:
        bundle = working

    bundle["CPU"] = cpus
    bundle["GPU"] = gpus
    bundle["memory"] = memory

    return PlacementGroupFactory([bundle])
|
.venv/lib/python3.11/site-packages/ray/tune/execution/tune_controller.py
ADDED
|
@@ -0,0 +1,2181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
import traceback
|
| 7 |
+
import warnings
|
| 8 |
+
from collections import defaultdict, deque
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from functools import partial
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
| 13 |
+
|
| 14 |
+
import ray
|
| 15 |
+
from ray.air import ResourceRequest
|
| 16 |
+
from ray.air.constants import TIME_THIS_ITER_S
|
| 17 |
+
from ray.air.execution import PlacementGroupResourceManager, ResourceManager
|
| 18 |
+
from ray.air.execution._internal import RayActorManager, TrackedActor
|
| 19 |
+
from ray.exceptions import RayActorError, RayTaskError
|
| 20 |
+
from ray.train import CheckpointConfig
|
| 21 |
+
from ray.train._internal.session import _FutureTrainingResult, _TrainingResult
|
| 22 |
+
from ray.train._internal.storage import StorageContext
|
| 23 |
+
from ray.tune.callback import Callback, CallbackList
|
| 24 |
+
from ray.tune.error import TuneError, _AbortTrialExecution, _TuneStopTrialError
|
| 25 |
+
from ray.tune.execution.class_cache import _ActorClassCache
|
| 26 |
+
from ray.tune.execution.experiment_state import (
|
| 27 |
+
_ExperimentCheckpointManager,
|
| 28 |
+
_find_newest_experiment_checkpoint,
|
| 29 |
+
)
|
| 30 |
+
from ray.tune.execution.insufficient_resources_manager import (
|
| 31 |
+
_InsufficientResourcesManager,
|
| 32 |
+
)
|
| 33 |
+
from ray.tune.execution.placement_groups import PlacementGroupFactory
|
| 34 |
+
from ray.tune.experiment import Experiment, Trial
|
| 35 |
+
from ray.tune.experiment.trial import (
|
| 36 |
+
_change_working_directory,
|
| 37 |
+
_get_trainable_kwargs,
|
| 38 |
+
_Location,
|
| 39 |
+
_noop_logger_creator,
|
| 40 |
+
_TrialInfo,
|
| 41 |
+
)
|
| 42 |
+
from ray.tune.result import (
|
| 43 |
+
DEBUG_METRICS,
|
| 44 |
+
DEFAULT_METRIC,
|
| 45 |
+
DONE,
|
| 46 |
+
RESULT_DUPLICATE,
|
| 47 |
+
SHOULD_CHECKPOINT,
|
| 48 |
+
STDERR_FILE,
|
| 49 |
+
STDOUT_FILE,
|
| 50 |
+
TRIAL_INFO,
|
| 51 |
+
)
|
| 52 |
+
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
| 53 |
+
from ray.tune.search import BasicVariantGenerator, SearchAlgorithm
|
| 54 |
+
from ray.tune.stopper import NoopStopper, Stopper
|
| 55 |
+
from ray.tune.tune_config import ResumeConfig
|
| 56 |
+
from ray.tune.utils import flatten_dict, warn_if_slow
|
| 57 |
+
from ray.tune.utils.log import Verbosity, _dedup_logs, has_verbosity
|
| 58 |
+
from ray.tune.utils.object_cache import _ObjectCache
|
| 59 |
+
from ray.tune.utils.resource_updater import _ResourceUpdater
|
| 60 |
+
from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder
|
| 61 |
+
from ray.util.annotations import DeveloperAPI
|
| 62 |
+
from ray.util.debug import log_once
|
| 63 |
+
|
| 64 |
+
logger = logging.getLogger(__name__)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@DeveloperAPI
|
| 68 |
+
class TuneController:
|
| 69 |
+
CKPT_FILE_TMPL = "experiment_state-{}.json"
|
| 70 |
+
RAISE = "RAISE"
|
| 71 |
+
|
    def __init__(
        self,
        *,
        search_alg: Optional[SearchAlgorithm] = None,
        placeholder_resolvers: Optional[Dict[Tuple, Any]] = None,
        scheduler: Optional[TrialScheduler] = None,
        stopper: Optional[Stopper] = None,
        resume_config: Optional[ResumeConfig] = None,
        fail_fast: bool = False,
        checkpoint_period: Union[str, int] = None,
        callbacks: Optional[List[Callback]] = None,
        metric: Optional[str] = None,
        trial_checkpoint_config: Optional[CheckpointConfig] = None,
        storage: Optional[StorageContext] = None,
        reuse_actors: bool = False,
        resource_manager_factory: Optional[Callable[[], ResourceManager]] = None,
        _trainer_api: bool = False,
    ):
        """Initialize the controller driving a Tune experiment.

        Sets up the actor manager, trial bookkeeping containers, result
        buffering, checkpointing, and (optionally) restores a previous run.

        Args:
            search_alg: Search algorithm producing trials. Defaults to
                ``BasicVariantGenerator``.
            placeholder_resolvers: Mapping used to resolve placeholder refs
                in trial configs before trials are added.
            scheduler: Trial scheduler. Defaults to ``FIFOScheduler``.
            stopper: Experiment-level stopper. Defaults to ``NoopStopper``.
            resume_config: If given, restore controller/trial state from the
                latest experiment checkpoint.
            fail_fast: ``False``, ``True``, or the string ``"RAISE"``
                (case-insensitive) to re-raise trial errors immediately.
            checkpoint_period: Experiment checkpoint period in seconds, or
                ``"auto"``. Defaults to env var ``TUNE_GLOBAL_CHECKPOINT_S``.
            callbacks: User callbacks, wrapped into a ``CallbackList``.
            metric: Default metric name for the run.
            trial_checkpoint_config: Per-trial checkpoint configuration.
            storage: Storage context for experiment/trial artifacts.
            reuse_actors: Whether to cache and reuse trial actors.
            resource_manager_factory: Factory for the resource manager;
                defaults to ``PlacementGroupResourceManager``.
            _trainer_api: Internal flag — whether this run comes through the
                Trainer API (affects insufficient-resource warnings).
        """
        if resource_manager_factory:
            resource_manager = resource_manager_factory()
        else:
            resource_manager = PlacementGroupResourceManager()

        self._actor_manager = RayActorManager(resource_manager=resource_manager)

        # Cache of wrapped trainable actor classes (avoids re-wrapping).
        self._class_cache = _ActorClassCache()

        # Resource status
        self._resource_updater = _ResourceUpdater(None)

        # Actor <-> Trial mappings
        self._actor_to_trial: Dict[TrackedActor, Trial] = {}
        self._trial_to_actor: Dict[Trial, TrackedActor] = {}

        # Resources <-> Trial
        self._resources_to_pending_trials: Dict[
            ResourceRequest, Set[Trial]
        ] = defaultdict(set)

        # Keep track of actor states
        self._pending_trials: Set[Trial] = set()
        self._pending_trials_list: List[Trial] = []

        self._running_trials: Set[Trial] = set()

        self._paused_trials: Set[Trial] = set()

        self._stopped_trials: Set[Trial] = set()
        self._failed_trials: Set[Trial] = set()

        self._resetting_trials: Set[Trial] = set()
        self._staged_trials: Set[Trial] = set()

        # Removed actors
        self._started_actors: Set[TrackedActor] = set()

        # Map of tracked actors -> timestamp
        # The timestamp is when we requested the stop.
        # We track these actors here to force a
        # cleanup after some time (as they might be hanging).
        # Todo: This timeout logic should be moved into the actor manager.
        # This map is populated whenever we request an actor stop:
        # - Regular STOP decision
        # - Removing an actor because its trial REUSEs a different trial's actor
        # - Removing a cached actor because it's not needed anymore
        # Actors are only tracked in this map if they actually started (not if they
        # were only requested but never started).
        # Actors are removed from this map:
        # - When the STOP resolved and the actor actually stopped
        # - When they are forcefully cleaned up after the timeout.
        self._stopping_actors: Dict[TrackedActor, float] = {}
        self._earliest_stopping_actor: float = float("inf")
        self._actor_cleanup_timeout: int = int(
            os.environ.get("TUNE_FORCE_TRIAL_CLEANUP_S", "600")
        )
        self._actor_force_cleanup_timeout: int = 10

        # Reuse actors
        self._reuse_actors = reuse_actors
        self._actor_cache = _ObjectCache(may_keep_one=True)

        # Trial metadata for experiment checkpoints
        self._trials_to_cache: Set[Trial] = set()
        self._trial_metadata: Dict[str, str] = {}

        # TRAINING
        # Result buffering knobs, all overridable via environment variables.
        self._buffer_length = int(os.getenv("TUNE_RESULT_BUFFER_LENGTH", 1))
        self._buffer_min_time_s = float(os.getenv("TUNE_RESULT_BUFFER_MIN_TIME_S", 0.0))
        self._buffer_max_time_s = float(
            os.getenv("TUNE_RESULT_BUFFER_MAX_TIME_S", 100.0)
        )

        # Legacy TrialRunner init
        self._search_alg = search_alg or BasicVariantGenerator()
        self._placeholder_resolvers = placeholder_resolvers
        self._scheduler_alg = scheduler or FIFOScheduler()
        self._callbacks = CallbackList(callbacks or [])
        self._insufficient_resources_manager = _InsufficientResourcesManager(
            for_train=_trainer_api
        )
        self._pending_trial_queue_times = {}

        self._max_pending_trials = _get_max_pending_trials(self._search_alg)

        self._storage = storage
        self._metric = metric

        self._total_time = 0
        self._iteration = 0
        self._has_errored = False
        self._fail_fast = fail_fast
        if isinstance(self._fail_fast, str):
            self._fail_fast = self._fail_fast.upper()
            if self._fail_fast == self.RAISE:
                warnings.warn(
                    "fail_fast='raise' detected. Be careful when using this "
                    "mode as resources (such as Ray processes, "
                    "file descriptors, and temporary files) may not be "
                    "cleaned up properly. To use "
                    "a safer mode, use fail_fast=True."
                )
            else:
                raise ValueError(
                    "fail_fast must be one of {bool, RAISE}. " f"Got {self._fail_fast}."
                )

        self._print_trial_errors = bool(
            int(os.environ.get("TUNE_PRINT_ALL_TRIAL_ERRORS", "1"))
        )

        self._trials: List[Trial] = []
        self._live_trials: Set[Trial] = set()  # Set of non-terminated trials
        self._cached_trial_decisions = {}
        self._queued_trial_decisions = {}

        self._stop_queue = []
        self._should_stop_experiment = False  # used by TuneServer

        self._stopper = stopper or NoopStopper()

        self._start_time = time.time()

        # Human-readable session id; also keys the experiment state file name.
        self._session_str = datetime.fromtimestamp(self._start_time).strftime(
            "%Y-%m-%d_%H-%M-%S"
        )

        if checkpoint_period is None:
            checkpoint_period = os.getenv("TUNE_GLOBAL_CHECKPOINT_S", "auto")

        self._checkpoint_period = checkpoint_period
        self._trial_checkpoint_config = trial_checkpoint_config or CheckpointConfig()
        self._checkpoint_manager = self._create_checkpoint_manager()

        self._resumed = False

        if resume_config is not None:
            # Use the metadata file to restore TuneController state
            try:
                self.resume(resume_config=resume_config)
                self._resumed = True
            except Exception as e:
                if has_verbosity(Verbosity.V3_TRIAL_DETAILS):
                    logger.error(str(e))
                logger.exception("Failed to restore the run state.")
                if self._fail_fast:
                    raise
                logger.info("Restarting experiment.")
        else:
            logger.debug("Starting a new experiment.")
| 242 |
+
def _wrapped(self):
|
| 243 |
+
"""Return wrapped tune controller to be passed to scheduler/searchers."""
|
| 244 |
+
return TrialRunnerWrapper(
|
| 245 |
+
self,
|
| 246 |
+
trial_executor=_FakeRayTrialExecutor(self),
|
| 247 |
+
runner_whitelist_attr={
|
| 248 |
+
"search_alg",
|
| 249 |
+
"get_trials",
|
| 250 |
+
"get_live_trials",
|
| 251 |
+
"_set_trial_status",
|
| 252 |
+
"pause_trial",
|
| 253 |
+
"stop_trial",
|
| 254 |
+
"_schedule_trial_save",
|
| 255 |
+
},
|
| 256 |
+
executor_whitelist_attr={
|
| 257 |
+
"has_resources_for_trial",
|
| 258 |
+
"pause_trial",
|
| 259 |
+
"save",
|
| 260 |
+
"_resource_updater",
|
| 261 |
+
},
|
| 262 |
+
)
|
| 263 |
+
|
    @property
    def resumed(self):
        """Whether this controller was successfully restored from a previous run."""
        return self._resumed
    @property
    def search_alg(self):
        """The search algorithm generating trials for this run."""
        return self._search_alg
    @property
    def scheduler_alg(self):
        """The trial scheduler used by this run."""
        return self._scheduler_alg
| 276 |
+
def setup_experiments(
|
| 277 |
+
self, experiments: List[Experiment], total_num_samples: int
|
| 278 |
+
) -> None:
|
| 279 |
+
"""Obtains any necessary information from experiments.
|
| 280 |
+
|
| 281 |
+
Mainly used to setup callbacks.
|
| 282 |
+
|
| 283 |
+
Args:
|
| 284 |
+
experiments: List of Experiments
|
| 285 |
+
to use.
|
| 286 |
+
total_num_samples: Total number of samples
|
| 287 |
+
factoring in grid search samplers.
|
| 288 |
+
"""
|
| 289 |
+
experiment = experiments[0]
|
| 290 |
+
spec = experiment.public_spec if experiment else {}
|
| 291 |
+
spec["total_num_samples"] = total_num_samples
|
| 292 |
+
self._callbacks.setup(**spec)
|
| 293 |
+
|
    def end_experiment_callbacks(self) -> None:
        """Calls ``on_experiment_end`` method in callbacks.

        Invoked once at the end of the run with the full trial list.
        """
        self._callbacks.on_experiment_end(trials=self._trials)
    @property
    def experiment_state_file_name(self) -> str:
        """File name of the experiment state JSON, keyed by the session string."""
        return self.CKPT_FILE_TMPL.format(self._session_str)
| 302 |
+
@property
|
| 303 |
+
def experiment_state_path(self) -> str:
|
| 304 |
+
"""Returns the local experiment checkpoint path."""
|
| 305 |
+
return Path(
|
| 306 |
+
self._storage.experiment_driver_staging_path,
|
| 307 |
+
self.experiment_state_file_name,
|
| 308 |
+
).as_posix()
|
| 309 |
+
|
    @property
    def experiment_path(self) -> str:
        """Path of the experiment directory on the (possibly remote) storage FS."""
        return self._storage.experiment_fs_path
    def _create_checkpoint_manager(self):
        """Build the experiment checkpoint manager from current settings.

        Note: ``num_to_keep`` is passed as the trial-checkpoint sync cadence;
        presumably intentional coupling — confirm in
        ``_ExperimentCheckpointManager``.
        """
        return _ExperimentCheckpointManager(
            storage=self._storage,
            checkpoint_period=self._checkpoint_period,
            sync_every_n_trial_checkpoints=self._trial_checkpoint_config.num_to_keep,
        )
    def save_to_dir(self):
        """Save TuneController state to the local staging experiment directory.

        This includes:
        - trial states
        - TuneController internal state (all the serializable attributes)
        - the searcher state
        - the callback states
        """
        # Get state from trial executor and runner
        runner_state = {
            # Trials
            "trial_data": list(self._get_trial_checkpoints().values()),
            # Experiment data
            "runner_data": self.__getstate__(),
            # Metadata
            "stats": {"start_time": self._start_time},
        }

        driver_staging_path = self._storage.experiment_driver_staging_path
        os.makedirs(driver_staging_path, exist_ok=True)
        # TuneFunctionEncoder serializes embedded callables so the state can
        # round-trip through JSON (decoded with TuneFunctionDecoder in
        # `resume`). NOTE(review): `_get_trial_checkpoints` and `__getstate__`
        # are defined outside this view; payload shape is determined there.
        with open(
            Path(driver_staging_path, self.experiment_state_file_name),
            "w",
        ) as f:
            json.dump(runner_state, f, cls=TuneFunctionEncoder)

        self._search_alg.save_to_dir(driver_staging_path, session_str=self._session_str)
        self._callbacks.save_to_dir(driver_staging_path, session_str=self._session_str)
    def checkpoint(self, force: bool = False, wait: bool = False):
        """Checkpoint experiment state and sync it to storage.

        Args:
            force: Passed through to the checkpoint manager; presumably
                checkpoints regardless of the configured period — confirm in
                ``_ExperimentCheckpointManager.sync_up_experiment_state``.
            wait: Passed through; presumably blocks until the sync finished.
        """
        self._checkpoint_manager.sync_up_experiment_state(
            save_fn=self.save_to_dir, force=force, wait=wait
        )
    def _requeue_restored_trials(
        self, trials: List[Trial], resume_config: ResumeConfig
    ):
        """Re-add restored trials to this controller per the resume policy.

        Each trial's fate is picked from the ``resume_config`` field matching
        its saved status (``errored`` / ``finished`` / ``unfinished``):

        - RESUME: re-queue the same trial (trial ID kept), clearing recorded
          error files and setting it back to PENDING.
        - RESTART: add a fresh copy via ``trial.reset()`` with no restore path.
        - SKIP: re-add the trial marked TERMINATED so it does not run again;
          trials already in ERROR keep their ERROR status.

        Raises:
            ValueError: If an unknown resume type is encountered.
        """
        # Set trial statuses according to the resume configuration.
        # Most recently active trials are processed first.
        for trial in sorted(
            trials, key=lambda t: t.run_metadata.last_result_time, reverse=True
        ):
            if trial.status == Trial.ERROR:
                resume_type = resume_config.errored
            elif trial.status == Trial.TERMINATED:
                resume_type = resume_config.finished
            else:  # Unfinished (PENDING, RUNNING, PAUSED)
                resume_type = resume_config.unfinished

            trial_to_add = None
            if resume_type == ResumeConfig.ResumeType.RESUME:
                # Keep trial ID on resume
                trial_to_add = trial
                trial_to_add.run_metadata.error_filename = None
                trial_to_add.run_metadata.pickled_error_filename = None
                trial_to_add.set_status(Trial.PENDING)
            elif resume_type == ResumeConfig.ResumeType.RESTART:
                trial_to_add = trial.reset()
                trial_to_add.restore_path = None
            elif resume_type == ResumeConfig.ResumeType.SKIP:
                trial_to_add = trial
                if trial_to_add.status != Trial.ERROR:
                    # Set the status to terminated to skip it.
                    # Keep errored trial status as ERROR.
                    trial_to_add.set_status(Trial.TERMINATED)
            else:
                raise ValueError(f"Unknown resume type: {resume_type}")
            assert trial_to_add is not None

            self.add_trial(trial_to_add)
    def _restore_trials(self, experiment_state: Dict) -> List[Trial]:
        """Rebuild Trial objects from a decoded experiment-state dict.

        Args:
            experiment_state: Decoded experiment state; its ``trial_data``
                entries are ``(json_state, runtime_metadata)`` pairs.

        Returns:
            The restored trials, with their storage contexts re-pointed at
            this controller's current storage location.
        """
        trials = []
        for trial_json_state, trial_runtime_metadata in experiment_state["trial_data"]:
            trial = Trial.from_json_state(trial_json_state)
            trial.restore_run_metadata(trial_runtime_metadata)

            # The following properties may be updated on restoration
            # Ex: moved local/cloud experiment directory

            # Propagate updated storage ctx properties to the trial's restored copy.
            # Shallow copy is intentional: only top-level attributes change.
            new_storage = copy.copy(trial.storage)
            new_storage.storage_filesystem = self._storage.storage_filesystem
            new_storage.storage_fs_path = self._storage.storage_fs_path
            new_storage.experiment_dir_name = self._storage.experiment_dir_name

            # ATTN: `trial.set_storage` is used intentionally, since it
            # also updates the absolute paths and filesystem of tracked checkpoints.
            trial.set_storage(new_storage)

            # Avoid creating logdir in client mode for returned trial results,
            # since the dir might not be creatable locally.
            # TODO(ekl) this is kind of a hack.
            if not ray.util.client.ray.is_connected():
                trial.init_local_path()  # Create logdir if it does not exist

            trials.append(trial)

        # NOTE: The restored run should reuse the same driver staging directory.
        # NOTE(review): assumes at least one restored trial exists — an empty
        # `trial_data` list would raise IndexError here; confirm upstream.
        self._storage._timestamp = trials[0].storage._timestamp

        return trials
    def resume(self, resume_config: ResumeConfig):
        """Resumes all checkpointed trials from previous run.

        Requires user to manually re-register their objects. Also stops
        all ongoing trials.

        Steps: (1) restore controller state from the newest experiment state
        file, (2) rebuild trials, (3) restore searcher/callback state from
        the synced-down staging dir, (4) re-queue trials per ``resume_config``.

        Raises:
            ValueError: If no experiment state file is found in the
                experiment directory.
        """
        # 1. Restore TuneController state
        # Find newest state file
        newest_state_path = _find_newest_experiment_checkpoint(
            self._storage.experiment_fs_path, fs=self._storage.storage_filesystem
        )

        if newest_state_path is None:
            raise ValueError(
                f"Tried to resume experiment from directory "
                f"'{self._storage.experiment_fs_path}', but no "
                f"experiment state file of the form '{TuneController.CKPT_FILE_TMPL}' "
                "was found. This is expected if you are launching a new experiment."
            )

        logger.info(
            "Restoring the run from the latest experiment state file: "
            f"{Path(newest_state_path).name}"
        )
        # State was written with TuneFunctionEncoder (see `save_to_dir`), so
        # decode with the matching TuneFunctionDecoder.
        with self._storage.storage_filesystem.open_input_stream(newest_state_path) as f:
            experiment_state = json.loads(f.readall(), cls=TuneFunctionDecoder)

        self.__setstate__(experiment_state["runner_data"])

        # 2. Get the trial states that the run left off at.
        trials = self._restore_trials(experiment_state)

        # 3. Restore search algorithm and callback state
        # Download the search algorithm and callback state to the driver staging dir.
        self._checkpoint_manager.sync_down_experiment_state()

        driver_staging_dir = self._storage.experiment_driver_staging_path
        if self._search_alg.has_checkpoint(driver_staging_dir):
            self._search_alg.restore_from_dir(driver_staging_dir)

        if self._callbacks.can_restore(driver_staging_dir):
            self._callbacks.restore_from_dir(driver_staging_dir)

        # 4. Re-queue trials as needed, depending on their status.
        self._requeue_restored_trials(trials, resume_config)
| 470 |
+
def update_max_pending_trials(self, max_pending_trials: Optional[int] = None):
|
| 471 |
+
self._max_pending_trials = max_pending_trials or _get_max_pending_trials(
|
| 472 |
+
self._search_alg
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
def update_pending_trial_resources(
|
| 476 |
+
self, resources: Union[dict, PlacementGroupFactory]
|
| 477 |
+
):
|
| 478 |
+
"""Update trial resources when resuming from checkpoint.
|
| 479 |
+
|
| 480 |
+
Only updating the pending ones.
|
| 481 |
+
"""
|
| 482 |
+
assert resources
|
| 483 |
+
if isinstance(resources, dict) and "gpu" not in resources:
|
| 484 |
+
resources["gpu"] = 0
|
| 485 |
+
for trial in self._trials:
|
| 486 |
+
if trial.status == Trial.PENDING:
|
| 487 |
+
trial.update_resources(resources=resources)
|
| 488 |
+
|
| 489 |
+
def is_finished(self):
    """Returns whether all trials have finished running."""
    # Check the (usually small) live set first for a fast negative answer;
    # only then confirm with a full pass over all trials and the searcher.
    if self._live_trials and not all(
        t.is_finished() for t in self._live_trials
    ):
        return False
    if not all(t.is_finished() for t in self._trials):
        return False
    return self._search_alg.is_finished()
|
| 500 |
+
|
| 501 |
+
def get_trial(self, tid):
    """Return the first trial with trial id ``tid``, or None if absent."""
    for candidate in self._trials:
        if candidate.trial_id == tid:
            return candidate
    return None
|
| 504 |
+
|
| 505 |
+
def get_trials(self):
    """Returns the list of trials managed by this TrialRunner.

    Note that the caller usually should not mutate trial state directly.
    """
    # NOTE(review): this returns the internal list itself, not a copy.
    return self._trials
|
| 511 |
+
|
| 512 |
+
def get_live_trials(self):
    """Returns the set of trials that are not in Trial.TERMINATED state."""
    # NOTE(review): this returns the internal set itself, not a copy.
    return self._live_trials
|
| 515 |
+
|
| 516 |
+
def add_trial(self, trial: Trial):
    """Adds a new trial to this TrialRunner.

    Trials may be added at any time.

    Args:
        trial: Trial to queue.
    """
    # If the config map has had all the references replaced with placeholders,
    # resolve them before adding the trial.
    if self._placeholder_resolvers:
        trial.resolve_config_placeholders(self._placeholder_resolvers)

    # With trial.config resolved, create placement group factory if needed.
    trial.create_placement_group_factory()

    self._trials.append(trial)
    if trial.status != Trial.TERMINATED:
        self._live_trials.add(trial)
    with warn_if_slow("scheduler.on_trial_add"):
        self._scheduler_alg.on_trial_add(self._wrapped(), trial)
    # Trial metadata must be flushed on the next experiment checkpoint.
    self._mark_trial_to_checkpoint(trial)

    logger.debug(f"Adding trial {trial} with status {trial.status}")

    # Per-status sets for O(1) membership checks elsewhere in the controller.
    status_str_map = {
        Trial.PENDING: self._pending_trials,
        Trial.RUNNING: self._running_trials,
        Trial.PAUSED: self._paused_trials,
        Trial.TERMINATED: self._stopped_trials,
        Trial.ERROR: self._failed_trials,
    }

    status_str_map[trial.status].add(trial)

    if trial.status == Trial.PENDING:
        # List preserves FIFO order; set keyed by resources enables fast
        # matching of pending trials to cached actors.
        self._pending_trials_list.append(trial)
        self._resources_to_pending_trials[trial.placement_group_factory].add(trial)
|
| 554 |
+
|
| 555 |
+
def _update_trial_queue(self, blocking: bool = False, timeout: int = 600) -> bool:
|
| 556 |
+
"""Adds next trials to queue if possible.
|
| 557 |
+
|
| 558 |
+
Note that the timeout is currently unexposed to the user.
|
| 559 |
+
|
| 560 |
+
Args:
|
| 561 |
+
blocking: Blocks until either a trial is available
|
| 562 |
+
or is_finished (timeout or search algorithm finishes).
|
| 563 |
+
timeout: Seconds before blocking times out.
|
| 564 |
+
|
| 565 |
+
Returns:
|
| 566 |
+
Boolean indicating if a new trial was created or not.
|
| 567 |
+
"""
|
| 568 |
+
trial = self._search_alg.next_trial()
|
| 569 |
+
if blocking and not trial:
|
| 570 |
+
start = time.time()
|
| 571 |
+
# Checking `is_finished` instead of _search_alg.is_finished
|
| 572 |
+
# is fine because blocking only occurs if all trials are
|
| 573 |
+
# finished and search_algorithm is not yet finished
|
| 574 |
+
while (
|
| 575 |
+
not trial and not self.is_finished() and time.time() - start < timeout
|
| 576 |
+
):
|
| 577 |
+
logger.debug("Blocking for next trial...")
|
| 578 |
+
trial = self._search_alg.next_trial()
|
| 579 |
+
time.sleep(1)
|
| 580 |
+
|
| 581 |
+
if trial:
|
| 582 |
+
self.add_trial(trial)
|
| 583 |
+
return True
|
| 584 |
+
|
| 585 |
+
return False
|
| 586 |
+
|
| 587 |
+
def _used_resources_string(self) -> str:
|
| 588 |
+
allocated_resources = self._actor_manager.get_live_actors_resources()
|
| 589 |
+
|
| 590 |
+
return self._resource_updater.debug_string(allocated_resources)
|
| 591 |
+
|
| 592 |
+
def on_step_begin(self):
    """Refresh the cluster resource snapshot before a control-loop step."""
    self._resource_updater.update_avail_resources()
|
| 594 |
+
|
| 595 |
+
def on_step_end(self):
    """Opportunistically clean up cached and stopping actors after a step."""
    # force_all=False: only evict/kill what is already eligible.
    self._cleanup_cached_actors(force_all=False)
    self._cleanup_stopping_actors(force_all=False)
|
| 598 |
+
|
| 599 |
+
def _cleanup_cached_actors(self, force_all: bool = False):
    """Evict cached (reusable) actors that are no longer needed.

    Args:
        force_all: If True, evict every cached actor regardless of demand.
    """
    if (
        self._search_alg.is_finished()
        and not self._staged_trials
        and self._actor_cache.total_max_objects == 0
    ):
        # If there are no more trials coming in, no trials are pending execution,
        # and we don't explicitly want to cache objects, we can evict the full
        # cache.
        force_all = True

    for tracked_actor in self._actor_cache.flush_cached_objects(
        force_all=force_all
    ):
        logger.debug(f"Cleaning up cached actor: {tracked_actor}")
        # Unset termination callbacks as no trial is associated
        tracked_actor.set_on_stop(None)
        tracked_actor.set_on_error(None)
        self._remove_actor(tracked_actor=tracked_actor)
|
| 618 |
+
|
| 619 |
+
def _cleanup_stopping_actors(self, force_all: bool = False):
    """Force-kill actors whose graceful stop has timed out.

    Args:
        force_all: If True, process all stopping actors now instead of only
            those past ``self._actor_cleanup_timeout``.
    """
    now = time.monotonic()

    if (
        not force_all
        and now - self._earliest_stopping_actor <= self._actor_cleanup_timeout
    ):
        # If the earliest actor to timeout has not reached the timeout, return
        return

    # This is a bit costly, so we want to avoid running it too often
    # (oldest stop request first).
    times = deque(
        sorted(
            [
                (timestamp, tracked_actor)
                for tracked_actor, timestamp in self._stopping_actors.items()
            ],
            key=lambda item: item[0],
        )
    )

    while times and (
        force_all or time.monotonic() - times[0][0] > self._actor_cleanup_timeout
    ):
        if (
            time.monotonic() - times[0][0] < self._actor_force_cleanup_timeout
        ) and self._actor_manager.is_actor_started(tracked_actor=times[0][1]):
            # Even if force_all=True, we give the actors time to clean up
            self._actor_manager.next(timeout=1)
            continue

        _, tracked_actor = times.popleft()

        if tracked_actor not in self._stopping_actors:
            # Actor stopping has been handled by the block above
            continue

        if self._actor_manager.is_actor_started(tracked_actor=tracked_actor):
            logger.debug(f"Forcefully killing actor: {tracked_actor}")
            self._actor_manager.remove_actor(tracked_actor=tracked_actor, kill=True)
        self._stopping_actors.pop(tracked_actor)

    # Remember the next deadline so the early-return above stays cheap.
    if times:
        self._earliest_stopping_actor = times[0][0]
    else:
        self._earliest_stopping_actor = float("inf")
|
| 665 |
+
|
| 666 |
+
def step(self):
    """Run one iteration of the experiment control loop.

    Order matters: ask the searcher for trials, start actors, process one
    actor-manager event, maybe stop the experiment, then checkpoint.

    Raises:
        TuneError: If called after all trials have finished.
    """
    if self.is_finished():
        raise TuneError("Called step when all trials finished?")

    with warn_if_slow("on_step_begin"):
        self.on_step_begin()

    with warn_if_slow("callbacks.on_step_begin"):
        self._callbacks.on_step_begin(
            iteration=self._iteration, trials=self._trials
        )

    # Ask searcher for more trials
    self._maybe_update_trial_queue()

    # Start actors for added trials
    self._maybe_add_actors()

    # Handle one event
    if not self._actor_manager.next(timeout=0.1):
        # If there are no actors running, warn about potentially
        # insufficient resources
        if not self._actor_manager.num_live_actors:
            self._insufficient_resources_manager.on_no_available_trials(
                self.get_trials()
            )

    # Maybe stop whole experiment
    self._stop_experiment_if_needed()

    # Maybe save experiment state
    try:
        self.checkpoint()
    except Exception as e:
        # Log before re-raising so the failure cause reaches the driver log.
        logger.warning(f"Trial controller checkpointing failed: {str(e)}")
        raise e

    self._iteration += 1

    with warn_if_slow("on_step_end"):
        self.on_step_end()
    with warn_if_slow("callbacks.on_step_end"):
        self._callbacks.on_step_end(iteration=self._iteration, trials=self._trials)
|
| 709 |
+
|
| 710 |
+
def _set_trial_status(self, trial: Trial, status: str):
    """Set trial to a specific status.

    This will keep track of trials with specific statuses in sets.

    For PENDING and PAUSED trials we also keep a list of trials to be able
    to retain FIFO ordering. See ``_maybe_add_actors`` for details.

    Lastly we also keep a mapping from resources to pending/paused trials
    to be able to efficiently start trials for cached actors.
    """
    current_status = trial.status

    if current_status == status:
        logger.debug(f"Trial {trial} already has status {status}. Skipping update.")
        return

    status_str_map = {
        Trial.PENDING: self._pending_trials,
        Trial.RUNNING: self._running_trials,
        Trial.PAUSED: self._paused_trials,
        Trial.TERMINATED: self._stopped_trials,
        Trial.ERROR: self._failed_trials,
    }

    logger.debug(
        f"Setting status for trial {trial} from {current_status} to {status}"
    )

    # Internal-consistency checks: the trial must be tracked under its
    # current status and not yet under the new one.
    assert trial in status_str_map[current_status], (trial, current_status)
    assert trial not in status_str_map[status], (trial, status)

    status_str_map[current_status].remove(trial)
    status_str_map[status].add(trial)

    # We keep a log for pending trials for FIFO scheduling.
    # We do not need to remove from this list as we will just discard
    # items that are in this list but not in the respective set.
    if status == Trial.PENDING:
        self._pending_trials_list.append(trial)
        self._resources_to_pending_trials[trial.placement_group_factory].add(trial)
    else:
        self._resources_to_pending_trials[trial.placement_group_factory].discard(
            trial
        )

    trial.set_status(status)
|
| 757 |
+
|
| 758 |
+
def _get_trial_checkpoints(self) -> Dict[str, str]:
|
| 759 |
+
for trial in self._trials_to_cache:
|
| 760 |
+
self._trial_metadata[trial.trial_id] = trial.get_json_state()
|
| 761 |
+
self._trials_to_cache.clear()
|
| 762 |
+
return self._trial_metadata
|
| 763 |
+
|
| 764 |
+
def _mark_trial_to_checkpoint(self, trial: Trial):
    """Mark a trial's metadata dirty so it is serialized on the next save."""
    self._trials_to_cache.add(trial)
|
| 766 |
+
|
| 767 |
+
###
|
| 768 |
+
# UPDATE TRIALS
|
| 769 |
+
def _maybe_update_trial_queue(self):
|
| 770 |
+
"""Ask the searcher for more trials."""
|
| 771 |
+
if self._search_alg.is_finished():
|
| 772 |
+
return
|
| 773 |
+
|
| 774 |
+
dont_wait_for_trial = (
|
| 775 |
+
self._pending_trials or self._running_trials or self._paused_trials
|
| 776 |
+
)
|
| 777 |
+
|
| 778 |
+
while len(self._pending_trials) < self._max_pending_trials:
|
| 779 |
+
if not self._update_trial_queue(blocking=not dont_wait_for_trial):
|
| 780 |
+
break
|
| 781 |
+
dont_wait_for_trial = True
|
| 782 |
+
|
| 783 |
+
def _cleanup_trials(self):
    """Stop all trial actors and tear down the actor manager at experiment end."""
    logger.debug("CLEANING UP all trials")

    # Snapshot keys: _schedule_trial_stop mutates _actor_to_trial.
    for tracked_actor in list(self._actor_to_trial):
        trial = self._actor_to_trial[tracked_actor]
        logger.debug(
            f"Scheduling trial stop at end of experiment (trial {trial}): "
            f"{tracked_actor}"
        )
        self._schedule_trial_stop(trial)

    # Clean up cached actors now
    self._cleanup_cached_actors(force_all=True)

    # Give actors up to 5 seconds of event processing to shut down gracefully.
    start = time.monotonic()
    while time.monotonic() - start < 5 and self._actor_manager.num_total_actors:
        if _dedup_logs("actor_manager_cleanup", str(start)):
            logger.debug(
                "Waiting for actor manager to clean up final state [dedup]"
            )
        self._actor_manager.next(timeout=1)

    logger.debug("Force cleanup of remaining actors")
    self._cleanup_stopping_actors(force_all=True)

    self._actor_manager.cleanup()
|
| 809 |
+
|
| 810 |
+
def _remove_actor(self, tracked_actor: TrackedActor):
    """Schedule a graceful stop for an actor and track it for cleanup."""
    # Ask the actor to stop itself; keep the future so removal can await it.
    stop_future = self._actor_manager.schedule_actor_task(
        tracked_actor, "stop", _return_future=True
    )
    now = time.monotonic()

    if self._actor_manager.remove_actor(
        tracked_actor, kill=False, stop_future=stop_future
    ):
        # If the actor was previously alive, track
        self._stopping_actors[tracked_actor] = now
        self._earliest_stopping_actor = min(self._earliest_stopping_actor, now)
|
| 822 |
+
|
| 823 |
+
###
|
| 824 |
+
# ADD ACTORS
|
| 825 |
+
def _maybe_add_actors(self) -> None:
    """Add actors for pending and paused trials.

    For actors that have not been staged, yet, we request an actor.

    For actors that have been staged, already, we try to reuse a cached actor.

    First, we handle the trial that the scheduler chooses to run.

    Then, we handle all trials that are pending.

    Lastly, we see if we have cached actors that we can assign to a pending or
    paused trial. This can be the case when a trial has not been staged, yet,
    for instance because the number of staging trials was too large.
    """

    ###
    # 1: Start trial that the scheduler wants to run
    with warn_if_slow("choose_trial_to_run"):
        trial_to_run = self._scheduler_alg.choose_trial_to_run(self._wrapped())

    if trial_to_run:
        if _dedup_logs("trial_to_run_chosen", trial_to_run.trial_id):
            logger.debug(
                f"Chose trial to run from scheduler: {trial_to_run} [dedup]"
            )
        if (
            trial_to_run not in self._staged_trials
            and trial_to_run not in self._trial_to_actor
        ):
            logger.debug(f"Staging trial to run: {trial_to_run}")
            self._set_trial_status(trial_to_run, Trial.PENDING)
            self._staged_trials.add(trial_to_run)
            self._actor_cache.increase_max(trial_to_run.placement_group_factory)
            # schedule_trial_actor also potentially uses cached actors
            self._schedule_trial_actor(trial_to_run)
        else:
            # Otherwise, only try to use the cached actor
            if _dedup_logs("trial_to_run_reuse", trial_to_run.trial_id):
                logger.debug(
                    f"Trying to re-use actor for trial to run: {trial_to_run} "
                    f"[dedup]"
                )
            self._maybe_reuse_cached_actor(trial_to_run)

    ###
    # 2: Start trials that are PENDING
    # Local helper (intentionally shadows the method name); processes the
    # FIFO list of pending trials and returns the remaining candidates.
    def _maybe_add_actors(candidates: List[Trial]):
        new_candidates = []

        while candidates:
            if self._actor_manager.num_pending_actors >= self._max_pending_trials:
                break

            trial = candidates.pop(0)

            # If the trial is part of the list, but not of the set,
            # we just ignore it. Removing it from the list on status
            # change is too expensive.
            if trial not in self._pending_trials:
                continue

            if trial in self._trial_to_actor:
                new_candidates.append(trial)
                continue

            if trial in self._staged_trials:
                self._maybe_reuse_cached_actor(trial)
                continue

            logger.debug(f"Scheduling actor for enqueued trial: {trial}")
            self._staged_trials.add(trial)
            self._actor_cache.increase_max(trial.placement_group_factory)
            self._schedule_trial_actor(trial)

        return new_candidates + candidates

    self._pending_trials_list = _maybe_add_actors(self._pending_trials_list)

    ###
    # 3: Start any trial that can be started with a cached actor
    if self._actor_cache.num_cached_objects:
        for resource in self._resources_to_pending_trials:
            if not self._resources_to_pending_trials[resource]:
                continue

            if not self._actor_cache.has_cached_object(resource):
                continue

            start_trial = self._resources_to_pending_trials[resource].pop()
            logger.debug(
                f"Trying to re-use actor for enqueued trial: {start_trial}"
            )
            if not self._maybe_reuse_cached_actor(start_trial):
                # Reuse failed: put the trial back for a later attempt.
                self._resources_to_pending_trials[resource].add(start_trial)
            else:
                if start_trial not in self._staged_trials:
                    self._staged_trials.add(start_trial)
                    self._actor_cache.increase_max(
                        start_trial.placement_group_factory
                    )
|
| 926 |
+
|
| 927 |
+
def _maybe_reuse_cached_actor(self, trial: Trial) -> bool:
    """Maybe reuse a cached actor for a trial.

    If an actor has been scheduled for the trial already,
    this will remove the original actor.

    Returns:
        True if a cached actor was assigned (or a reset is already in
        flight for this trial), False if no suitable cached actor exists.
    """
    if trial in self._resetting_trials:
        # A reset to reuse an actor for this trial is already in progress.
        return True

    resource_request = trial.placement_group_factory

    if not self._actor_cache.has_cached_object(resource_request):
        return False

    cached_actor = self._actor_cache.pop_cached_object(resource_request)
    logger.debug(f"Reusing ACTOR for trial {trial}: {cached_actor}")

    if trial in self._trial_to_actor:
        original_actor = self._trial_to_actor.pop(trial)
        self._actor_to_trial.pop(original_actor)

        logger.debug(f"Removing ORIGINAL ACTOR for trial {trial}: {original_actor}")
        self._remove_actor(tracked_actor=original_actor)

    self._trial_to_actor[trial] = cached_actor
    self._actor_to_trial[cached_actor] = trial

    # Todo: get rid of Trial.runner
    ray_actor = self._actor_manager._live_actors_to_ray_actors_resources[
        cached_actor
    ][0]
    trial.set_ray_actor(ray_actor)

    # The reused actor must be reset to this trial's config before training.
    self._schedule_trial_reset(trial, trial.config, trial.experiment_tag)

    return True
|
| 963 |
+
|
| 964 |
+
def _schedule_trial_actor(self, trial: Trial):
    """Schedule an actor for a trial.

    If a cached actor is available, use it. Otherwise, request a
    new actor.

    Raises:
        RuntimeError: If an actor is already associated with the trial.
    """
    logger.debug(f"Trying to schedule new ACTOR for trial {trial}")

    assert trial.status == Trial.PENDING

    trial.init_local_path()
    # We checkpoint metadata here to try mitigating logdir duplication
    self._mark_trial_to_checkpoint(trial)

    if self._maybe_reuse_cached_actor(trial):
        return

    # Safeguard
    if trial in self._trial_to_actor:
        raise RuntimeError(
            f"Tried to request a new actor for trial {trial}, but an old "
            f"actor still exists. This can lead to leaked resources. The old "
            f"actor should be removed first. "
            f"This is an internal problem in Ray Tune. If you encounter this "
            f"error, please raise an issue on "
            f"https://github.com/ray-project/ray/issues"
        )

    trainable_cls = trial.get_trainable_cls()
    if not trainable_cls:
        # Unknown trainable: abort this trial but keep the experiment going.
        exception = _AbortTrialExecution(
            f"Invalid trainable: {trial.trainable_name}. If you passed "
            f"a string, make sure the trainable was registered before."
        )
        trial.handle_error(exception)
        self._schedule_trial_stop(trial, exception=exception)
        return

    _actor_cls = self._class_cache.get(trainable_cls)

    trial.set_location(_Location())
    trainable_kwargs = _get_trainable_kwargs(trial=trial)

    with _change_working_directory(trial):
        tracked_actor = self._actor_manager.add_actor(
            cls=_actor_cls,
            resource_request=trial.placement_group_factory,
            kwargs=trainable_kwargs,
            on_start=self._actor_started,
            on_stop=self._actor_stopped,
            on_error=self._actor_failed,
        )
        self._trial_to_actor[trial] = tracked_actor
        self._actor_to_trial[tracked_actor] = trial

    logger.debug(
        f"Scheduled new ACTOR for trial {trial}: {tracked_actor}. "
        f"Resources: {trial.placement_group_factory}"
    )
|
| 1023 |
+
|
| 1024 |
+
def _unstage_trial_with_resources(self, trial: Trial):
    """Unstage trial, or one with the same resources as ``trial``."""
    # Case 1: The trial we started was staged. Just remove it.
    if trial in self._staged_trials:
        self._staged_trials.discard(trial)
        self._actor_cache.decrease_max(trial.placement_group_factory)
        return

    # Case 2: We staged a trial "A" with the same resources, but our trial
    # "B" was selected by the scheduler to run. The resource manager only
    # tracks resources, not trials, so remove any staged trial with the
    # same resource requirements instead.
    wanted = trial.placement_group_factory
    substitute = next(
        (
            staged
            for staged in self._staged_trials
            if staged.placement_group_factory == wanted
        ),
        None,
    )
    if substitute is not None:
        self._staged_trials.discard(substitute)
        self._actor_cache.decrease_max(substitute.placement_group_factory)
        return

    raise RuntimeError(
        "Started a trial with resources requested by a different trial, but "
        "this trial was lost. This is an error in Ray Tune's execution "
        "logic. Please raise a GitHub issue at "
        "https://github.com/ray-project/ray/issues"
    )
|
| 1057 |
+
|
| 1058 |
+
def _maybe_cache_trial_actor(self, trial: Trial) -> bool:
    """Cache trial actor for reuse, if needed.

    We will only cache as many actors as are needed to fulfill any pending
    resource requests for actors with the same resource requirements.
    E.g. if we have 6 running trials and 4 additional staged actors, we will only
    cache up to 4 of the running trial actors when they finish.

    One exception is the case when we have no cached actors, yet. In that case,
    we will always cache the actor in this method.

    Later, in `_cleanup_cached_actors`, we will check again if we need this cached
    actor. That method will keep the actor if we don't have any staged trials,
    because we don't know at that point if the next trial might require the same
    resources. But because there is no staged trial, it is safe to keep the actor
    around, as it won't occupy resources needed by another trial until it's staged.

    Returns:
        True if the actor was cached, False otherwise.
    """
    if not self._reuse_actors:
        return False

    if self._search_alg.is_finished() and not self._staged_trials:
        logger.debug(
            f"Not caching actor of trial {trial} as the search is over "
            f"and no more trials are staged."
        )
        return False

    tracked_actor = self._trial_to_actor[trial]

    if (
        not self._actor_manager.is_actor_started(tracked_actor)
        or self._actor_manager.is_actor_failed(tracked_actor)
        or tracked_actor not in self._started_actors
    ):
        logger.debug(
            f"Not caching actor of trial {trial} as it has not been started, yet: "
            f"{tracked_actor}"
        )
        return False

    if not self._actor_cache.cache_object(
        trial.placement_group_factory, tracked_actor
    ):
        logger.debug(
            f"Could not cache actor of trial {trial} for "
            "reuse, as there are no pending trials "
            "requiring its resources."
        )
        return False

    logger.debug(f"Caching actor of trial {trial} for re-use: {tracked_actor}")

    # Detach the actor from the trial; it now belongs to the cache.
    tracked_actor = self._trial_to_actor.pop(trial)
    self._actor_to_trial.pop(tracked_actor)

    trial.set_ray_actor(None)

    return True
|
| 1116 |
+
|
| 1117 |
+
def _actor_started(self, tracked_actor: TrackedActor, log: str = "STARTED"):
    """Callback invoked when a trial's actor has started.

    Args:
        tracked_actor: The actor that came up.
        log: Label used in the debug log message.
    """
    self._started_actors.add(tracked_actor)

    trial = self._actor_to_trial[tracked_actor]

    logger.debug(f"Actor {log} for trial {trial}: {tracked_actor}")

    # The trial now holds real resources; it no longer counts as staged.
    self._unstage_trial_with_resources(trial)

    ray_actor = self._actor_manager._live_actors_to_ray_actors_resources[
        tracked_actor
    ][0]
    trial.set_ray_actor(ray_actor)

    self._callbacks.on_trial_start(
        iteration=self._iteration, trials=self._trials, trial=trial
    )

    self._set_trial_status(trial, Trial.RUNNING)

    self._mark_trial_to_checkpoint(trial)

    # Restore from a checkpoint if one exists; otherwise start training.
    if not self._schedule_trial_restore(trial):
        self._schedule_trial_train(trial)
|
| 1141 |
+
|
| 1142 |
+
def _actor_stopped(self, tracked_actor: TrackedActor):
    """Callback invoked when an actor has stopped: detach it from all state."""
    associated_trial = self._actor_to_trial.pop(tracked_actor, None)
    if associated_trial is not None:
        logger.debug(f"Actor STOPPED for trial {associated_trial}: {tracked_actor}")
        self._trial_to_actor.pop(associated_trial)
        associated_trial.set_ray_actor(None)

    logger.debug(f"Actor STOPPED: {tracked_actor}")

    self._stopping_actors.pop(tracked_actor, None)
    self._started_actors.discard(tracked_actor)
|
| 1153 |
+
|
| 1154 |
+
def _actor_failed(self, tracked_actor: TrackedActor, exception: Exception):
    """Callback invoked when an actor failed: record the failure and clean up."""
    trial = self._actor_to_trial[tracked_actor]

    logger.debug(
        f"Actor FAILED for trial {trial}: {tracked_actor}. "
        f"Exception: {exception}"
    )

    if trial in (self._pending_trials | self._paused_trials):
        # First, set to running (needed downstream in _process_trial_failure)
        self._set_trial_status(trial, Trial.RUNNING)

        logger.debug(
            f"Trial {trial} failed in its creation task. Unstaging "
            f"to allow it to be re-scheduled."
        )

        self._unstage_trial_with_resources(trial)
        self._trial_task_failure(trial, exception=exception)

    self._actor_manager.clear_actor_task_futures(tracked_actor)

    # Clean up actor
    tracked_actor.set_on_stop(None)
    tracked_actor.set_on_error(None)
    self._actor_manager.remove_actor(tracked_actor, kill=False)

    # Trigger actor stopped callback
    self._actor_stopped(tracked_actor)
|
| 1183 |
+
|
| 1184 |
+
def _schedule_trial_task(
    self,
    trial: Trial,
    method_name: str,
    args: Optional[Tuple] = None,
    kwargs: Optional[Dict] = None,
    on_result: Optional[Callable[[Trial, Any], None]] = None,
    on_error: Optional[Callable[[Trial, Exception], None]] = None,
    _return_future: bool = False,
) -> Optional[ray.ObjectRef]:
    """Schedule an actor task future for a trial.

    This is a wrapper around ``ActorManager.schedule_actor_task``. This method
    retrieves the tracked actor for a trial to kick off the task.

    It also wraps around the callbacks, retrieving the trial object given the
    tracked actor.

    Args:
        trial: Trial whose tracked actor should run the task.
        method_name: Name of the remote method to invoke.
        args: Positional arguments for the remote method.
        kwargs: Keyword arguments for the remote method.
        on_result: Called with ``(trial, *result)`` when the future resolves.
        on_error: Called with ``(trial, exception)`` when the future fails.
        _return_future: If True, return the underlying object reference.

    Returns:
        The scheduled future if ``_return_future`` is True, else None.
    """

    tracked_actor = self._trial_to_actor[trial]

    _on_result = None
    _on_error = None

    args = args or tuple()
    kwargs = kwargs or {}

    if on_result:

        def _on_result(tracked_actor: TrackedActor, *args, **kwargs):
            assert trial == self._actor_to_trial[tracked_actor]
            logger.debug(
                f"Future {method_name.upper()} RESOLVED for trial {trial}: "
                f"{args}, {kwargs}"
            )
            try:
                on_result(trial, *args, **kwargs)
            except Exception as e:
                logger.debug(
                    f"Error handling {method_name.upper()} result "
                    f"for trial {trial}: {e}"
                )
                # BUG FIX: was ``e is TuneError``, which compares the
                # exception *instance* to the *class* and is always False,
                # so TuneErrors were needlessly re-wrapped.
                if isinstance(e, TuneError) or self._fail_fast == self.RAISE:
                    raise e
                else:
                    raise TuneError(traceback.format_exc())

    if on_error:

        def _on_error(tracked_actor: TrackedActor, exception: Exception):
            # If the actor failed, it has already been cleaned up.
            if tracked_actor not in self._actor_to_trial:
                assert isinstance(exception, RayActorError), type(exception)
            else:
                assert trial == self._actor_to_trial[tracked_actor]

            logger.debug(
                f"Future {method_name.upper()} FAILED for trial {trial}: "
                f"{exception}"
            )
            try:
                on_error(trial, exception)
            except Exception as e:
                logger.debug(
                    f"Error handling {method_name.upper()} failure "
                    f"for trial {trial}: {e}"
                )
                # Same fix as in _on_result above.
                if isinstance(e, TuneError) or self._fail_fast == self.RAISE:
                    raise e
                else:
                    raise TuneError(traceback.format_exc())

    logger.debug(f"Future {method_name.upper()} SCHEDULED for trial {trial}")

    with _change_working_directory(trial):
        future = self._actor_manager.schedule_actor_task(
            tracked_actor=tracked_actor,
            method_name=method_name,
            args=args,
            kwargs=kwargs,
            on_result=_on_result,
            on_error=_on_error,
            _return_future=_return_future,
        )
        if _return_future:
            return future
|
| 1270 |
+
|
| 1271 |
+
def _queue_decision(self, trial, decision):
|
| 1272 |
+
# Get old decision, setting it to the current decision if it isn't set
|
| 1273 |
+
old_decision = self._queued_trial_decisions.setdefault(trial.trial_id, decision)
|
| 1274 |
+
|
| 1275 |
+
# Stopping always takes precedence. If we decided to stop, just quit
|
| 1276 |
+
if old_decision is TrialScheduler.STOP:
|
| 1277 |
+
return
|
| 1278 |
+
|
| 1279 |
+
# The old decision wasn't STOP. We update the decision only if it is
|
| 1280 |
+
# STOP or PAUSE. The action will only be CONTINUE if it was set by
|
| 1281 |
+
# the first received result and was never updated after that.
|
| 1282 |
+
if decision is TrialScheduler.STOP or decision is TrialScheduler.PAUSE:
|
| 1283 |
+
self._queued_trial_decisions[trial.trial_id] = decision
|
| 1284 |
+
|
| 1285 |
+
def _execute_action(self, trial: Trial, decision: str, after_save: bool = False):
|
| 1286 |
+
"""Executes action based on decision.
|
| 1287 |
+
|
| 1288 |
+
Args:
|
| 1289 |
+
trial: Trial to act on.
|
| 1290 |
+
decision: Scheduling decision to undertake.
|
| 1291 |
+
"""
|
| 1292 |
+
if decision == TrialScheduler.CONTINUE:
|
| 1293 |
+
self._schedule_trial_train(trial)
|
| 1294 |
+
elif decision == TrialScheduler.PAUSE:
|
| 1295 |
+
self.pause_trial(trial, should_checkpoint=not after_save)
|
| 1296 |
+
elif decision == TrialScheduler.STOP:
|
| 1297 |
+
self.stop_trial(trial)
|
| 1298 |
+
elif decision == TrialScheduler.NOOP:
|
| 1299 |
+
pass
|
| 1300 |
+
else:
|
| 1301 |
+
raise ValueError("Invalid decision: {}".format(decision))
|
| 1302 |
+
|
| 1303 |
+
def _maybe_execute_queued_decision(self, trial: Trial, after_save: bool = False):
|
| 1304 |
+
# `self._queued_trial_decisions` now contains a final decision
|
| 1305 |
+
# based on all results
|
| 1306 |
+
final_decision = self._queued_trial_decisions.pop(trial.trial_id, None)
|
| 1307 |
+
if final_decision:
|
| 1308 |
+
logger.debug(
|
| 1309 |
+
f"Executing final queued decision for {trial}: {final_decision}"
|
| 1310 |
+
)
|
| 1311 |
+
self._execute_action(trial, final_decision, after_save=after_save)
|
| 1312 |
+
|
| 1313 |
+
def _stop_experiment_if_needed(self):
|
| 1314 |
+
"""Stops all trials."""
|
| 1315 |
+
fail_fast = self._fail_fast and self._has_errored
|
| 1316 |
+
if self._stopper.stop_all() or fail_fast or self._should_stop_experiment:
|
| 1317 |
+
self._search_alg.set_finished()
|
| 1318 |
+
[
|
| 1319 |
+
self._schedule_trial_stop(t)
|
| 1320 |
+
for t in self._trials
|
| 1321 |
+
if t.status not in {Trial.ERROR, Trial.TERMINATED}
|
| 1322 |
+
]
|
| 1323 |
+
|
| 1324 |
+
###
|
| 1325 |
+
# Failure
|
| 1326 |
+
def _trial_task_failure(self, trial: Trial, exception: Exception):
|
| 1327 |
+
if self._fail_fast == self.RAISE:
|
| 1328 |
+
raise exception
|
| 1329 |
+
else:
|
| 1330 |
+
if self._print_trial_errors:
|
| 1331 |
+
logger.error(f"Trial task failed for trial {trial}", exc_info=exception)
|
| 1332 |
+
self._process_trial_failure(trial, exception=exception)
|
| 1333 |
+
|
| 1334 |
+
    def _process_trial_failure(
        self,
        trial: Trial,
        exception: Union[TuneError, RayTaskError, RayActorError],
    ):
        """Handle trial failure.

        Attempt trial recovery if possible, clean up state otherwise.

        Args:
            trial: Failed trial.
            exception: Exception prior to invoking this method.
        """
        self._has_errored = True
        trial.handle_error(exception)
        if trial.status == Trial.RUNNING and trial.should_recover():
            # The trial still has retries left: requeue it and notify callbacks.
            self._try_recover(trial, exc=exception)
            self._callbacks.on_trial_recover(
                iteration=self._iteration, trials=self._trials, trial=trial
            )
        elif trial.status in {Trial.RUNNING, Trial.PENDING}:
            # No recovery possible: notify scheduler/searcher, stop the actor,
            # and fire the error callbacks.
            self._scheduler_alg.on_trial_error(self, trial)
            self._search_alg.on_trial_complete(trial.trial_id, error=True)
            self._schedule_trial_stop(trial, exception=exception)
            self._callbacks.on_trial_error(
                iteration=self._iteration, trials=self._trials, trial=trial
            )
|
| 1361 |
+
|
| 1362 |
+
    def _schedule_trial_stop(self, trial: Trial, exception: Optional[Exception] = None):
        """Stop the trial's actor and update trial bookkeeping.

        Args:
            trial: Trial to stop.
            exception: If set, the trial is marked ERROR (instead of
                TERMINATED) and its actor is never cached for reuse.
        """
        if trial.status == Trial.ERROR:
            logger.debug(f"Not requesting trial STOP as it is ERROR already: {trial}")
            return

        logger.debug(f"Requesting to STOP actor for trial {trial}")

        if trial.is_saving:
            # Defer the actual stop until the in-flight save resolves.
            logger.debug(
                f"Trial {trial} is currently saving/pausing. Scheduling STOP after "
                f"save resolved."
            )
            self._cached_trial_decisions[trial.trial_id] = TrialScheduler.STOP

        # Clear transient save/restore state before changing status.
        trial.temporary_state.saving_to = None
        trial.temporary_state.restoring_from = None

        self._set_trial_status(trial, Trial.ERROR if exception else Trial.TERMINATED)
        trial.set_location(_Location())

        if trial not in self._trial_to_actor:
            logger.debug(f"Will not STOP trial actor as it is not live: {trial}")
            return

        tracked_actor = self._trial_to_actor[trial]

        # Drop in-flight futures so they cannot resolve after the stop.
        self._actor_manager.clear_actor_task_futures(tracked_actor=tracked_actor)

        self._mark_trial_to_checkpoint(trial)

        if not exception and self._maybe_cache_trial_actor(trial):
            # Trial runner has been cached
            return

        logger.debug(f"Terminating actor for trial {trial}: {tracked_actor}")

        # Remove the trial<->actor mapping in both directions.
        tracked_actor = self._trial_to_actor.pop(trial)
        self._actor_to_trial.pop(tracked_actor)

        trial.set_ray_actor(None)

        self._remove_actor(tracked_actor=tracked_actor)
|
| 1404 |
+
|
| 1405 |
+
    def stop_trial(self, trial):
        """The canonical implementation of stopping a trial.

        Trials may be in any external status when this function is called.
        If trial is in state PENDING or PAUSED, calls `on_trial_remove` for
        scheduler and `on_trial_complete()` for search_alg.
        If trial is in state RUNNING, calls `on_trial_complete` for scheduler
        and search_alg if RUNNING. Caller to ensure that there is no
        outstanding future to be handled for the trial. If there is, the future
        would be discarded.
        """
        try:
            if trial.status in [Trial.ERROR, Trial.TERMINATED]:
                # Already finished -- nothing to do.
                return
            elif trial.status in [Trial.PENDING, Trial.PAUSED]:
                self._scheduler_alg.on_trial_remove(self, trial)
                self._search_alg.on_trial_complete(trial.trial_id)
            elif trial.status is Trial.RUNNING:
                # By this time trial.last_result should have been
                # updated already.
                self._scheduler_alg.on_trial_complete(
                    self, trial, flatten_dict(trial.last_result)
                )
                self._search_alg.on_trial_complete(
                    trial.trial_id, result=flatten_dict(trial.last_result)
                )
            self._callbacks.on_trial_complete(
                iteration=self._iteration, trials=self._trials, trial=trial
            )
            self._schedule_graceful_trial_stop(trial)
            self._live_trials.discard(trial)
        except Exception as e:
            logger.exception("Trial %s: Error stopping trial.", trial)
            if self._fail_fast == self.RAISE:
                raise
            # Wrap non-Tune exceptions so downstream handling is uniform.
            if isinstance(e, TuneError):
                self._process_trial_failure(trial, exception=e)
            else:
                self._process_trial_failure(
                    trial, _TuneStopTrialError(traceback.format_exc())
                )
|
| 1446 |
+
|
| 1447 |
+
def _schedule_graceful_trial_stop(self, trial: Trial):
|
| 1448 |
+
self._schedule_trial_export(trial)
|
| 1449 |
+
if trial.status != "ERROR":
|
| 1450 |
+
self._schedule_trial_stop(trial)
|
| 1451 |
+
|
| 1452 |
+
    def _schedule_trial_pause(self, trial: Trial, should_checkpoint: bool = True):
        """Pause ``trial``, optionally creating a checkpoint first.

        With ``should_checkpoint``, a save is scheduled and the PAUSE is
        cached to execute once the save resolves; otherwise the actor is
        stopped immediately and the trial is set to PAUSED.
        """
        if trial not in self._trial_to_actor:
            logger.debug(
                f"Trial PAUSE requested for trial {trial} but trial is already "
                f"stopping. Ignoring."
            )
            return

        if should_checkpoint:
            # Defer the actual pause until the checkpoint is processed.
            self._cached_trial_decisions[trial.trial_id] = TrialScheduler.PAUSE
            self._schedule_trial_save(trial=trial)
        else:
            self._schedule_trial_stop(trial)
            self._set_trial_status(trial, Trial.PAUSED)
|
| 1466 |
+
|
| 1467 |
+
###
|
| 1468 |
+
# TRAIN
|
| 1469 |
+
|
| 1470 |
+
def _schedule_trial_train(self, trial: Trial):
|
| 1471 |
+
args = ()
|
| 1472 |
+
method_name = "train"
|
| 1473 |
+
|
| 1474 |
+
buffer_length, buffer_time_s = self._maybe_buffer_training(trial)
|
| 1475 |
+
|
| 1476 |
+
if buffer_length > 1:
|
| 1477 |
+
method_name = "train_buffered"
|
| 1478 |
+
args = (buffer_length, buffer_time_s)
|
| 1479 |
+
|
| 1480 |
+
logger.debug(f"Scheduling future {method_name.upper()} for trial {trial}")
|
| 1481 |
+
|
| 1482 |
+
self._schedule_trial_task(
|
| 1483 |
+
trial=trial,
|
| 1484 |
+
method_name=method_name,
|
| 1485 |
+
args=args,
|
| 1486 |
+
on_result=self._on_training_result,
|
| 1487 |
+
on_error=self._trial_task_failure,
|
| 1488 |
+
)
|
| 1489 |
+
|
| 1490 |
+
def _maybe_buffer_training(self, trial: Trial) -> Tuple[int, float]:
|
| 1491 |
+
buffer_time_s = max(
|
| 1492 |
+
self._buffer_min_time_s,
|
| 1493 |
+
min(self._buffer_max_time_s, self._actor_manager.num_actor_tasks // 10),
|
| 1494 |
+
)
|
| 1495 |
+
buffer_length = self._buffer_length
|
| 1496 |
+
|
| 1497 |
+
if buffer_length > 1 and trial.checkpoint_at_end:
|
| 1498 |
+
# If a trial checkpoint can be triggered externally,
|
| 1499 |
+
# it is not safe to buffer results.
|
| 1500 |
+
if log_once("trial_executor_buffer_checkpoint"):
|
| 1501 |
+
logger.warning(
|
| 1502 |
+
"Disabling buffered training as you passed "
|
| 1503 |
+
"`checkpoint_at_end` to `train.CheckpointConfig()`."
|
| 1504 |
+
)
|
| 1505 |
+
return 1, buffer_time_s
|
| 1506 |
+
|
| 1507 |
+
if buffer_length > 1 and trial.checkpoint_freq > 0:
|
| 1508 |
+
return min(buffer_length, trial.checkpoint_freq), buffer_time_s
|
| 1509 |
+
|
| 1510 |
+
return buffer_length, buffer_time_s
|
| 1511 |
+
|
| 1512 |
+
###
|
| 1513 |
+
# RESULT
|
| 1514 |
+
|
| 1515 |
+
def _on_training_result(self, trial, result):
|
| 1516 |
+
if not isinstance(result, list):
|
| 1517 |
+
result = [result]
|
| 1518 |
+
with warn_if_slow("process_trial_result"):
|
| 1519 |
+
self._process_trial_results(trial, result)
|
| 1520 |
+
self._maybe_execute_queued_decision(trial, after_save=False)
|
| 1521 |
+
|
| 1522 |
+
    def _process_trial_results(self, trial, results):
        """Process a batch of (possibly buffered) results for ``trial``.

        Stops early when a non-training future was scheduled (decision is
        None) or when the decision is STOP.
        """
        logger.debug(f"Processing trial results for trial {trial}: {results}")
        with warn_if_slow(
            "process_trial_results",
            message="Processing trial results took {duration:.3f} s, "
            "which may be a performance bottleneck. Please consider "
            "reporting results less frequently to Ray Tune.",
        ):
            for i, result in enumerate(results):
                with warn_if_slow("process_trial_result"):
                    decision = self._process_trial_result(trial, result)
                if decision is None:
                    # If we didn't get a decision, this means a
                    # non-training future (e.g. a save) was scheduled.
                    # We do not allow processing more results then.
                    if i < len(results) - 1:
                        if log_once("tune_controller_buffer_checkpoint"):
                            logger.warning(
                                f"Trial {trial} has a non-training future "
                                f"scheduled but {len(results) - i} results "
                                f"left to process. This means that a "
                                f"checkpoint was requested, but buffered "
                                f"training was continued before it was "
                                f"saved. Consider using non-buffered "
                                f"training by setting the env variable "
                                f"`TUNE_RESULT_BUFFER_LENGTH=1`."
                            )
                elif decision == TrialScheduler.STOP:
                    # If the decision is to stop the trial,
                    # ignore all results that came after that.
                    break
|
| 1553 |
+
|
| 1554 |
+
    def _process_trial_result(self, trial, result):
        """Process a single result for ``trial``.

        Consults the stopper, scheduler and searcher, notifies callbacks,
        triggers checkpointing, and either queues the resulting decision
        or (while a save is in flight) caches it.

        Returns:
            The scheduling decision, or None if the decision was cached
            because the trial is currently saving.
        """
        result.update(trial_id=trial.trial_id)
        is_duplicate = RESULT_DUPLICATE in result
        force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
        # TrialScheduler and SearchAlgorithm still receive a
        # notification because there may be special handling for
        # the `on_trial_complete` hook.
        if is_duplicate:
            logger.debug("Trial finished without logging 'done'.")
            result = trial.last_result
            result.update(done=True)

        self._total_time += result.get(TIME_THIS_ITER_S, 0)

        flat_result = flatten_dict(result)
        self._validate_result_metrics(flat_result)

        if self._stopper(trial.trial_id, result) or trial.should_stop(flat_result):
            decision = TrialScheduler.STOP
        else:
            with warn_if_slow("scheduler.on_trial_result"):
                decision = self._scheduler_alg.on_trial_result(
                    self._wrapped(), trial, flat_result
                )
        if decision == TrialScheduler.STOP:
            result.update(done=True)
        else:
            # Only updating search alg if the trial is not to be stopped.
            with warn_if_slow("search_alg.on_trial_result"):
                self._search_alg.on_trial_result(trial.trial_id, flat_result)

        # If this is not a duplicate result, the callbacks should
        # be informed about the result.
        if not is_duplicate:
            with warn_if_slow("callbacks.on_trial_result"):
                self._callbacks.on_trial_result(
                    iteration=self._iteration,
                    trials=self._trials,
                    trial=trial,
                    result=result.copy(),
                )
            trial.update_last_result(result)
            # Include in next experiment checkpoint
            self._mark_trial_to_checkpoint(trial)

        # Checkpoints to disk. This should be checked even if
        # the scheduler decision is STOP or PAUSE. Note that
        # PAUSE only checkpoints to memory and does not update
        # the global checkpoint state.
        if decision != TrialScheduler.PAUSE:
            # TODO(justinvyu): This is a temporary hack to fix pausing trials.
            # We already schedule a save task in `pause_trial`, so no need
            # to do it again here.
            self._checkpoint_trial_if_needed(trial, force=force_checkpoint)

        if trial.is_saving:
            logger.debug(f"Caching trial decision for trial {trial}: {decision}")
            # Cache decision to execute on after the save is processed.
            # This prevents changing the trial's state or kicking off
            # another training step prematurely.
            if not self._cached_trial_decisions.get(trial.trial_id) or decision in {
                TrialScheduler.PAUSE,
                TrialScheduler.STOP,
            }:
                # If already set, only overwrite if it's a PAUSE or STOP. This is
                # to avoid that CONTINUE decisions from a training step that resolve
                # late overwrite PAUSE/STOP decision.
                self._cached_trial_decisions[trial.trial_id] = decision
            return None
        else:
            self._queue_decision(trial, decision)
            return decision
|
| 1626 |
+
|
| 1627 |
+
    def _validate_result_metrics(self, result):
        """
        Check if any of the required metrics was not reported
        in the last result. If the only items are ``done`` or any of
        DEBUG_METRICS, this means that no result was ever received and
        the trial just returned. This is also okay and will not raise
        an error.

        This will ignore checking for the DEFAULT_METRIC.

        Raises:
            ValueError: If a metric required by the tuner, scheduler, or
                searcher is missing from ``result`` and strict checking
                is not disabled via TUNE_DISABLE_STRICT_METRIC_CHECKING.
        """
        if int(os.environ.get("TUNE_DISABLE_STRICT_METRIC_CHECKING", 0)) != 1 and (
            len({k for k in result if k not in list(DEBUG_METRICS) + [DONE]}) > 1
        ):
            # DEFAULT_METRIC placeholders are excluded from validation.
            base_metric = self._metric if self._metric != DEFAULT_METRIC else None
            scheduler_metric = (
                self._scheduler_alg.metric
                if self._scheduler_alg.metric != DEFAULT_METRIC
                else None
            )
            search_metrics = (
                self._search_alg.metric
                if self._search_alg.metric != DEFAULT_METRIC
                else None
            )

            if isinstance(search_metrics, str):
                search_metrics = [search_metrics]

            # Determine the first missing metric and where it was required,
            # so the error message can point at the responsible component.
            if base_metric and base_metric not in result:
                report_metric = base_metric
                location = "tune.TuneConfig()"
            elif scheduler_metric and scheduler_metric not in result:
                report_metric = scheduler_metric
                location = type(self._scheduler_alg).__name__
            elif search_metrics and any(
                search_metric not in result for search_metric in search_metrics
            ):
                report_metric = list(
                    filter(
                        lambda search_metric: search_metric not in result,
                        search_metrics,
                    )
                )
                if len(report_metric) == 1:
                    report_metric = report_metric[0]
                location = type(self._search_alg).__name__
            else:
                report_metric = None
                location = None

            if report_metric:
                raise ValueError(
                    "Trial returned a result which did not include the "
                    "specified metric(s) `{}` that `{}` expects. "
                    "Make sure your calls to `tune.report()` include the "
                    "metric, or set the "
                    "TUNE_DISABLE_STRICT_METRIC_CHECKING "
                    "environment variable to 1. Result: {}".format(
                        report_metric, location, result
                    )
                )
|
| 1688 |
+
|
| 1689 |
+
###
|
| 1690 |
+
# SAVE
|
| 1691 |
+
    def _schedule_trial_save(
        self,
        trial: Trial,
        result: Optional[Dict] = None,
    ) -> Optional[_FutureTrainingResult]:
        """Schedule a checkpoint save on the trial's actor.

        Args:
            trial: Trial to save.
            result: Result to associate with the save; defaults to the
                trial's last result.

        Returns:
            A future training result wrapping the save, or None if the
            trial has no live actor.
        """
        if trial not in self._trial_to_actor:
            logger.debug(
                f"Trial SAVE requested for trial {trial} but trial is already "
                f"stopping. Ignoring."
            )
            return None

        result = result or trial.last_result

        future = self._schedule_trial_task(
            trial=trial,
            method_name="save",
            on_result=self._on_saving_result,
            on_error=self._trial_task_failure,
            _return_future=True,
        )
        # TODO(justinvyu): `trial.saving_to` (and trial.is_saving) is needed
        # in order to prevent a done=True result from executing a STOP decision
        # (which clears all futures) before the save gets processed.
        # Keep this in for now while `train` and `save` are 2 separate steps.
        trial.temporary_state.saving_to = _FutureTrainingResult(future)

        # `trial.saving_to` holds a future training result -- this is only used
        # in the case of PBT to block until the checkpoint is ready.
        # In all other situations, the checkpoint future is processed by the
        # actor event manager when it is ready.
        return trial.temporary_state.saving_to
|
| 1723 |
+
|
| 1724 |
+
    def _on_saving_result(self, trial, checkpoint_value: _TrainingResult):
        """Callback for a resolved save future.

        Processes the checkpoint, notifies callbacks, then executes any
        decision that was queued while the save was in flight.
        """
        with warn_if_slow("process_trial_save"):
            self._process_trial_save(trial, checkpoint_value)

        with warn_if_slow("callbacks.on_trial_save"):
            self._callbacks.on_trial_save(
                iteration=self._iteration, trials=self._trials, trial=trial
            )

        self._maybe_execute_queued_decision(trial, after_save=True)
|
| 1734 |
+
|
| 1735 |
+
    def _process_trial_save(self, trial: Trial, checkpoint_value: _TrainingResult):
        """Processes a trial save.

        Acts on the decision cached during the last `_process_trial` call.

        Args:
            trial: Trial being saved.
            checkpoint_value: Training result holding the checkpoint.
        """
        logger.debug("Trial %s: Processing trial save.", trial)

        try:
            if not checkpoint_value.checkpoint:
                logger.debug(f"Got empty checkpoint for trial {trial}")
            else:
                try:
                    self._callbacks.on_checkpoint(
                        iteration=self._iteration,
                        trials=self._trials,
                        trial=trial,
                        checkpoint=checkpoint_value.checkpoint,
                    )
                except Exception:
                    logger.warning(
                        "Error encountered during processing of callbacks. "
                        "Ray Train/Tune recently changed the checkpoint interface "
                        "that is passed to callbacks. If you implemented your own "
                        "callback with an `on_checkpoint` handler, please review "
                        "the checkpoint interface and adjust your code "
                        "accordingly."
                    )
                    raise

                trial.on_checkpoint(checkpoint_value)

                self._checkpoint_manager.on_trial_checkpoint(trial)

                # Include in next experiment checkpoint.
                self._mark_trial_to_checkpoint(trial)
        except Exception:
            # Checkpoint handling failures are logged but do not abort
            # the controller; the cached decision below is still applied.
            logger.exception(
                "Trial %s: Error handling checkpoint %s", trial, checkpoint_value
            )

        trial.temporary_state.saving_to = None
        decision = self._cached_trial_decisions.pop(trial.trial_id, None)
        if decision and checkpoint_value:
            self._queue_decision(trial, decision)
|
| 1781 |
+
|
| 1782 |
+
def _checkpoint_trial_if_needed(self, trial, force=False):
|
| 1783 |
+
"""Checkpoints trial based off trial.last_result."""
|
| 1784 |
+
if trial.should_checkpoint() or force:
|
| 1785 |
+
# Save trial runtime if possible.
|
| 1786 |
+
if trial.temporary_state.ray_actor:
|
| 1787 |
+
self._schedule_trial_save(trial)
|
| 1788 |
+
|
| 1789 |
+
###
|
| 1790 |
+
# RESTORE
|
| 1791 |
+
    def _schedule_trial_restore(self, trial: Trial) -> bool:
        """Schedule a restore from the trial's latest checkpoint.

        Returns:
            True if a restore task was scheduled, False if the trial
            has no checkpoint to restore from.
        """
        checkpoint_result = trial.latest_checkpoint_result

        if not checkpoint_result:
            logger.debug(f"Not restoring trial {trial}: No checkpoint found.")
            return False

        # TODO(justinvyu): Is this really needed?
        trial.temporary_state.restoring_from = checkpoint_result

        method_name = "restore"
        args = (checkpoint_result,)
        self._schedule_trial_task(
            trial=trial,
            method_name=method_name,
            args=args,
            kwargs={},
            on_result=self._on_restoring_result,
            on_error=self._trial_task_failure,
        )
        return True
|
| 1812 |
+
|
| 1813 |
+
    def _on_restoring_result(self, trial: Trial, result: Any):
        # Callback for a resolved restore future; the result payload itself
        # is not used.
        self._process_trial_restore(trial)
|
| 1815 |
+
|
| 1816 |
+
    def _process_trial_restore(self, trial: Trial):
        """Processes a trial restore.

        Marks the trial RUNNING again and immediately schedules the next
        training step.

        Args:
            trial: Trial being restored.
        """
        logger.debug("Trial %s: Processing trial restore.", trial)
        trial.on_restore()
        logger.debug("Trial %s: Restore processed successfully", trial)
        self._set_trial_status(trial, Trial.RUNNING)
        self._schedule_trial_train(trial)
        self._live_trials.add(trial)
|
| 1828 |
+
|
| 1829 |
+
    def _try_recover(
        self, trial: Trial, exc: Union[TuneError, RayTaskError, RayActorError]
    ):
        """Tries to recover trial.

        Notifies SearchAlgorithm and Scheduler if failure to recover.

        Args:
            trial: Trial to recover.
            exc: Exception prior to invoking this method.
        """
        # Any decision cached for this trial is stale now.
        self._cached_trial_decisions.pop(trial.trial_id, None)
        # Resetting this, in case that the trial is in saving status when it crashes.
        if trial.is_saving:
            trial.temporary_state.saving_to = None
        self._schedule_trial_stop(trial, exception=exc)

        logger.debug("Trial %s: Notifying Scheduler and requeueing.", trial)
        self._requeue_trial(trial)
|
| 1848 |
+
|
| 1849 |
+
def _requeue_trial(self, trial):
|
| 1850 |
+
"""Notification to TrialScheduler and requeue trial.
|
| 1851 |
+
|
| 1852 |
+
This does not notify the SearchAlgorithm because the function
|
| 1853 |
+
evaluation is still in progress.
|
| 1854 |
+
|
| 1855 |
+
"""
|
| 1856 |
+
self._scheduler_alg.on_trial_error(self, trial)
|
| 1857 |
+
self._set_trial_status(trial, status=Trial.PENDING)
|
| 1858 |
+
|
| 1859 |
+
# TODO(rliaw): Right now, this pushes the trial to the end of queue
|
| 1860 |
+
# because restoration can be expensive. However, this is not
|
| 1861 |
+
# ideal since it just hides the issue - a better fix would
|
| 1862 |
+
# be to use an actor table to detect the IP of the Trainable
|
| 1863 |
+
# and rsync the files there.
|
| 1864 |
+
# See https://github.com/ray-project/ray/issues/5168
|
| 1865 |
+
self._trials.pop(self._trials.index(trial))
|
| 1866 |
+
self._trials.append(trial)
|
| 1867 |
+
self._live_trials.add(trial)
|
| 1868 |
+
|
| 1869 |
+
with warn_if_slow("scheduler.on_trial_add"):
|
| 1870 |
+
self._scheduler_alg.on_trial_add(self._wrapped(), trial)
|
| 1871 |
+
|
| 1872 |
+
###
|
| 1873 |
+
# EXPORT
|
| 1874 |
+
    def _schedule_trial_export(self, trial: Trial):
        """Export the trial's model in its configured formats, if any.

        Blocks until the export task resolves (see TODO below).
        """
        if not trial.export_formats or len(trial.export_formats) <= 0:
            return

        # Todo: We are waiting here synchronously until the task resolved.
        # Instead, we should schedule the trial stop after the export resolved.
        # This requires changes in TrialRunner, which we can remove once the
        # legacy execution path has been removed.
        future = self._schedule_trial_task(
            trial=trial,
            method_name="export_model",
            args=(trial.export_formats,),
            on_result=None,
            on_error=self._trial_task_failure,
            _return_future=True,
        )
        # NOTE(review): reaches into the actor manager's private
        # `_actor_task_events` to block on the future -- confirm no public
        # API exists for this before refactoring.
        self._actor_manager._actor_task_events.resolve_future(future)
|
| 1891 |
+
|
| 1892 |
+
###
|
| 1893 |
+
# RESET
|
| 1894 |
+
    def _schedule_trial_reset(
        self,
        trial: Trial,
        new_config: Dict,
        new_experiment_tag: str,
    ):
        """Reset a trial's actor in place with a new config (actor reuse).

        Args:
            trial: Trial whose actor should be reset.
            new_config: Configuration the actor should adopt.
            new_experiment_tag: New experiment tag for the trial.
        """
        trial.set_experiment_tag(new_experiment_tag)
        trial.set_config(new_config)

        # Pass magic variables
        extra_config = copy.deepcopy(new_config)
        extra_config[TRIAL_INFO] = _TrialInfo(trial)

        stdout_file, stderr_file = trial.log_to_file
        extra_config[STDOUT_FILE] = stdout_file
        extra_config[STDERR_FILE] = stderr_file

        logger_creator = partial(
            _noop_logger_creator, logdir=trial.storage.trial_working_directory
        )

        # Track the in-flight reset so concurrent logic can account for it.
        self._resetting_trials.add(trial)
        self._schedule_trial_task(
            trial=trial,
            method_name="reset",
            args=(extra_config,),
            kwargs={
                "logger_creator": logger_creator,
                "storage": trial.storage,
            },
            on_result=self._on_trial_reset,
            on_error=self._trial_task_failure,
        )
|
| 1927 |
+
|
| 1928 |
+
    def _on_trial_reset(self, trial: Trial, success: bool):
        """Callback for a resolved actor reset.

        On failure the trial errors out and its actor is stopped; on
        success the (reused) actor is treated as freshly started.
        """
        self._resetting_trials.remove(trial)

        if not success:
            info = (
                "Trainable runner reuse requires reset_config() to be "
                "implemented and return True."
            )

            logger.error(f"Could not re-use actor for trial {trial}: {info}")

            exception = _AbortTrialExecution(info)

            trial.handle_error(exception)
            self._schedule_trial_stop(trial, exception=exception)
            return

        tracked_actor = self._trial_to_actor[trial]

        self._actor_started(tracked_actor, log="REUSED")
|
| 1948 |
+
|
| 1949 |
+
    def request_stop_trial(self, trial):
        """Queue ``trial`` to be stopped by the control loop."""
        self._stop_queue.append(trial)
|
| 1951 |
+
|
| 1952 |
+
    def request_stop_experiment(self):
        """Flag the whole experiment to be stopped by the control loop."""
        self._should_stop_experiment = True
|
| 1954 |
+
|
| 1955 |
+
def _process_stop_requests(self):
|
| 1956 |
+
while self._stop_queue:
|
| 1957 |
+
t = self._stop_queue.pop()
|
| 1958 |
+
self.stop_trial(t)
|
| 1959 |
+
|
| 1960 |
+
    def pause_trial(self, trial: Trial, should_checkpoint: bool = True):
        """Pause a trial and reset the necessary state variables for resuming later.

        Args:
            trial: Trial to pause.
            should_checkpoint: Whether or not an in-memory checkpoint should be created
                for this paused trial. Defaults to True.
        """
        # NOTE: The cached trial decision is not needed since we will overrule this
        # decision with PAUSE.
        self._cached_trial_decisions.pop(trial.trial_id, None)
        self._schedule_trial_pause(trial, should_checkpoint=should_checkpoint)
|
| 1972 |
+
|
| 1973 |
+
    def cleanup(self):
        """Cleanup trials and callbacks."""
        self._cleanup_trials()
        self.end_experiment_callbacks()
|
| 1977 |
+
|
| 1978 |
+
    def __getstate__(self):
        """Gets state for trial.

        Note that this is not used as a pickling override as
        does not have all fields.
        """
        state = self.__dict__.copy()
        # Strip all runtime-only state (live actors, queues, managers,
        # trial bookkeeping) -- only configuration-level state survives.
        for k in [
            "_trials",
            "_live_trials",
            "_stop_queue",
            "_search_alg",
            "_placeholder_resolvers",
            "_scheduler_alg",
            "_pending_trial_queue_times",
            "_callbacks",
            "_checkpoint_manager",
            "_storage",
            "_insufficient_resources_manager",
            "_actor_manager",
            "_class_cache",
            "_resource_updater",
            "_trials_to_cache",
            "_trial_metadata",
            "_actor_to_trial",
            "_trial_to_actor",
            "_resources_to_pending_trials",
            "_pending_trials",
            "_pending_trials_list",
            "_running_trials",
            "_paused_trials",
            "_stopped_trials",
            "_failed_trials",
            "_resetting_trials",
            "_started_actors",
            "_stopping_actors",
            "_staged_trials",
            "_actor_cache",
        ]:
            del state[k]
        return state
|
| 2019 |
+
|
| 2020 |
+
def __setstate__(self, state):
|
| 2021 |
+
# Use session_str from previous checkpoint if does not exist
|
| 2022 |
+
session_str = state.pop("_session_str")
|
| 2023 |
+
self.__dict__.setdefault("_session_str", session_str)
|
| 2024 |
+
# Use start_time from previous checkpoint if does not exist
|
| 2025 |
+
start_time = state.pop("_start_time")
|
| 2026 |
+
self.__dict__.setdefault("_start_time", start_time)
|
| 2027 |
+
|
| 2028 |
+
self.__dict__.update(state)
|
| 2029 |
+
self._checkpoint_manager = self._create_checkpoint_manager()
|
| 2030 |
+
|
| 2031 |
+
|
| 2032 |
+
class _TrialExecutorWrapper:
|
| 2033 |
+
"""Wraps around TrialExecutor class, intercepts API calls and warns users
|
| 2034 |
+
of restricted API access.
|
| 2035 |
+
|
| 2036 |
+
This is meant to facilitate restricting
|
| 2037 |
+
the current API exposure of TrialExecutor by TrialScheduler.
|
| 2038 |
+
"""
|
| 2039 |
+
|
| 2040 |
+
def __init__(
|
| 2041 |
+
self,
|
| 2042 |
+
trial_executor: "_FakeRayTrialExecutor",
|
| 2043 |
+
whitelist_attr: Optional[set] = None,
|
| 2044 |
+
):
|
| 2045 |
+
self._trial_executor = trial_executor
|
| 2046 |
+
self._whitelist_attr = whitelist_attr or set()
|
| 2047 |
+
|
| 2048 |
+
for attr in self._whitelist_attr:
|
| 2049 |
+
assert hasattr(self._trial_executor, attr)
|
| 2050 |
+
|
| 2051 |
+
def __getattr__(self, attr):
|
| 2052 |
+
if attr not in self._whitelist_attr:
|
| 2053 |
+
if log_once("restrict_accessing_trial_executor"):
|
| 2054 |
+
logger.warning(
|
| 2055 |
+
f"You are trying to access {attr} interface of "
|
| 2056 |
+
f"TrialExecutor in TrialScheduler, which is being "
|
| 2057 |
+
f"restricted. If you believe it is reasonable for "
|
| 2058 |
+
f"your scheduler to access this TrialExecutor API, "
|
| 2059 |
+
f"please reach out to Ray team on GitHub. A more "
|
| 2060 |
+
f"strict API access pattern would be enforced "
|
| 2061 |
+
f"starting 1.12.0"
|
| 2062 |
+
)
|
| 2063 |
+
return getattr(self._trial_executor, attr)
|
| 2064 |
+
|
| 2065 |
+
|
| 2066 |
+
@DeveloperAPI
|
| 2067 |
+
class TrialRunnerWrapper:
|
| 2068 |
+
"""Wraps around TrialRunner class, intercepts API calls and warns users
|
| 2069 |
+
of restricted API access.
|
| 2070 |
+
|
| 2071 |
+
This is meant to facilitate restricting
|
| 2072 |
+
the current API exposure of TrialRunner by TrialScheduler.
|
| 2073 |
+
"""
|
| 2074 |
+
|
| 2075 |
+
_EXECUTOR_ATTR = "trial_executor"
|
| 2076 |
+
|
| 2077 |
+
def __init__(
|
| 2078 |
+
self,
|
| 2079 |
+
tune_controller: TuneController,
|
| 2080 |
+
trial_executor: Any,
|
| 2081 |
+
runner_whitelist_attr: Optional[set] = None,
|
| 2082 |
+
executor_whitelist_attr: Optional[set] = None,
|
| 2083 |
+
):
|
| 2084 |
+
self._tune_controller = tune_controller
|
| 2085 |
+
self._trial_executor = _TrialExecutorWrapper(
|
| 2086 |
+
trial_executor, executor_whitelist_attr
|
| 2087 |
+
)
|
| 2088 |
+
self._runner_whitelist_attr = runner_whitelist_attr or set()
|
| 2089 |
+
|
| 2090 |
+
for attr in self._runner_whitelist_attr:
|
| 2091 |
+
assert hasattr(self, attr)
|
| 2092 |
+
|
| 2093 |
+
def __getattr__(self, attr):
|
| 2094 |
+
if attr == self._EXECUTOR_ATTR:
|
| 2095 |
+
return self._trial_executor
|
| 2096 |
+
if attr not in self._runner_whitelist_attr:
|
| 2097 |
+
if log_once("restrict_accessing_tune_controller"):
|
| 2098 |
+
logger.warning(
|
| 2099 |
+
f"You are trying to access {attr} interface of "
|
| 2100 |
+
f"TrialRunner in TrialScheduler, which is being "
|
| 2101 |
+
f"restricted. If you believe it is reasonable for "
|
| 2102 |
+
f"your scheduler to access this TrialRunner API, "
|
| 2103 |
+
f"please reach out to Ray team on GitHub. A more "
|
| 2104 |
+
f"strict API access pattern would be enforced "
|
| 2105 |
+
f"starting 1.12s.0"
|
| 2106 |
+
)
|
| 2107 |
+
return getattr(self._tune_controller, attr)
|
| 2108 |
+
|
| 2109 |
+
|
| 2110 |
+
def _get_max_pending_trials(search_alg: SearchAlgorithm) -> int:
|
| 2111 |
+
max_pending_trials = os.getenv("TUNE_MAX_PENDING_TRIALS_PG", "auto")
|
| 2112 |
+
|
| 2113 |
+
if max_pending_trials != "auto":
|
| 2114 |
+
return int(max_pending_trials)
|
| 2115 |
+
|
| 2116 |
+
# Else, auto detect.
|
| 2117 |
+
|
| 2118 |
+
# Only BasicVariantGenerator supports > 1 pending trials.
|
| 2119 |
+
# This is because we don't want to generate too many trials
|
| 2120 |
+
# before we fit the searcher model.
|
| 2121 |
+
if not isinstance(search_alg, BasicVariantGenerator):
|
| 2122 |
+
return 1
|
| 2123 |
+
|
| 2124 |
+
# Allow up to at least 200 pending trials to trigger fast autoscaling
|
| 2125 |
+
min_autoscaling_rate = 200
|
| 2126 |
+
|
| 2127 |
+
# Allow more pending trials for larger clusters (based on number of CPUs)
|
| 2128 |
+
cluster_cpus = ray.cluster_resources().get("CPU", 1.0)
|
| 2129 |
+
max_pending_trials = max(min_autoscaling_rate, int(cluster_cpus * 1.1))
|
| 2130 |
+
|
| 2131 |
+
if max_pending_trials > min_autoscaling_rate:
|
| 2132 |
+
logger.warning(
|
| 2133 |
+
f"The maximum number of pending trials has been "
|
| 2134 |
+
f"automatically set to the number of available "
|
| 2135 |
+
f"cluster CPUs, which is high "
|
| 2136 |
+
f"({max_pending_trials} CPUs/pending trials). "
|
| 2137 |
+
f"If you're running an experiment with a large number "
|
| 2138 |
+
f"of trials, this could lead to scheduling overhead. "
|
| 2139 |
+
f"In this case, consider setting the "
|
| 2140 |
+
f"`TUNE_MAX_PENDING_TRIALS_PG` environment variable "
|
| 2141 |
+
f"to the desired maximum number of concurrent pending trials."
|
| 2142 |
+
)
|
| 2143 |
+
|
| 2144 |
+
return max_pending_trials
|
| 2145 |
+
|
| 2146 |
+
|
| 2147 |
+
class _FakeRayTrialExecutor:
|
| 2148 |
+
"""The TuneController does not use a RayTrialExecutor anymore.
|
| 2149 |
+
|
| 2150 |
+
Instead, we pass this fake executor for searchers/schedulers to use
|
| 2151 |
+
as an interface.
|
| 2152 |
+
|
| 2153 |
+
In the future, we should have the searchers/schedulers either interact with
|
| 2154 |
+
the tune controller, or define a different API for more fine-grained scheduler
|
| 2155 |
+
control.
|
| 2156 |
+
"""
|
| 2157 |
+
|
| 2158 |
+
def __init__(self, tune_controller: TuneController):
|
| 2159 |
+
self._tune_controller = tune_controller
|
| 2160 |
+
|
| 2161 |
+
def pause_trial(self, trial: Trial, should_checkpoint: bool = True):
|
| 2162 |
+
return self._tune_controller._schedule_trial_pause(
|
| 2163 |
+
trial, should_checkpoint=should_checkpoint
|
| 2164 |
+
)
|
| 2165 |
+
|
| 2166 |
+
def save(
|
| 2167 |
+
self,
|
| 2168 |
+
trial: Trial,
|
| 2169 |
+
result: Optional[Dict] = None,
|
| 2170 |
+
) -> Optional[_FutureTrainingResult]:
|
| 2171 |
+
return self._tune_controller._schedule_trial_save(trial=trial, result=result)
|
| 2172 |
+
|
| 2173 |
+
def has_resources_for_trial(self, trial: Trial):
|
| 2174 |
+
return True
|
| 2175 |
+
|
| 2176 |
+
@property
|
| 2177 |
+
def _resource_updater(self):
|
| 2178 |
+
return self._tune_controller._resource_updater
|
| 2179 |
+
|
| 2180 |
+
def force_reconcilation_on_next_step_end(self):
|
| 2181 |
+
pass
|
.venv/lib/python3.11/site-packages/ray/tune/experiment/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.tune.experiment.experiment import Experiment, _convert_to_experiment_list
|
| 2 |
+
from ray.tune.experiment.trial import Trial
|
| 3 |
+
|
| 4 |
+
__all__ = ["Experiment", "_convert_to_experiment_list", "Trial"]
|
.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (429 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/config_parser.cpython-311.pyc
ADDED
|
Binary file (8.23 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/experiment.cpython-311.pyc
ADDED
|
Binary file (20.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/experiment/__pycache__/trial.cpython-311.pyc
ADDED
|
Binary file (51 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/experiment/config_parser.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
from ray.train import CheckpointConfig
|
| 5 |
+
from ray.tune.error import TuneError
|
| 6 |
+
from ray.tune.experiment import Trial
|
| 7 |
+
from ray.tune.resources import json_to_resources
|
| 8 |
+
|
| 9 |
+
# For compatibility under py2 to consider unicode as str
|
| 10 |
+
from ray.tune.utils.serialization import TuneFunctionEncoder
|
| 11 |
+
from ray.tune.utils.util import SafeFallbackEncoder
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _make_parser(parser_creator=None, **kwargs):
|
| 15 |
+
"""Returns a base argument parser for the ray.tune tool.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
parser_creator: A constructor for the parser class.
|
| 19 |
+
kwargs: Non-positional args to be passed into the
|
| 20 |
+
parser class constructor.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
if parser_creator:
|
| 24 |
+
parser = parser_creator(**kwargs)
|
| 25 |
+
else:
|
| 26 |
+
parser = argparse.ArgumentParser(**kwargs)
|
| 27 |
+
|
| 28 |
+
# Note: keep this in sync with rllib/train.py
|
| 29 |
+
parser.add_argument(
|
| 30 |
+
"--run",
|
| 31 |
+
default=None,
|
| 32 |
+
type=str,
|
| 33 |
+
help="The algorithm or model to train. This may refer to the name "
|
| 34 |
+
"of a built-on algorithm (e.g. RLlib's DQN or PPO), or a "
|
| 35 |
+
"user-defined trainable function or class registered in the "
|
| 36 |
+
"tune registry.",
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--stop",
|
| 40 |
+
default="{}",
|
| 41 |
+
type=json.loads,
|
| 42 |
+
help="The stopping criteria, specified in JSON. The keys may be any "
|
| 43 |
+
"field returned by 'train()' e.g. "
|
| 44 |
+
'\'{"time_total_s": 600, "training_iteration": 100000}\' to stop '
|
| 45 |
+
"after 600 seconds or 100k iterations, whichever is reached first.",
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--config",
|
| 49 |
+
default="{}",
|
| 50 |
+
type=json.loads,
|
| 51 |
+
help="Algorithm-specific configuration (e.g. env, hyperparams), "
|
| 52 |
+
"specified in JSON.",
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--resources-per-trial",
|
| 56 |
+
default=None,
|
| 57 |
+
type=json_to_resources,
|
| 58 |
+
help="Override the machine resources to allocate per trial, e.g. "
|
| 59 |
+
'\'{"cpu": 64, "gpu": 8}\'. Note that GPUs will not be assigned '
|
| 60 |
+
"unless you specify them here. For RLlib, you probably want to "
|
| 61 |
+
"leave this alone and use RLlib configs to control parallelism.",
|
| 62 |
+
)
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--num-samples",
|
| 65 |
+
default=1,
|
| 66 |
+
type=int,
|
| 67 |
+
help="Number of times to repeat each trial.",
|
| 68 |
+
)
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--checkpoint-freq",
|
| 71 |
+
default=0,
|
| 72 |
+
type=int,
|
| 73 |
+
help="How many training iterations between checkpoints. "
|
| 74 |
+
"A value of 0 (default) disables checkpointing.",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--checkpoint-at-end",
|
| 78 |
+
action="store_true",
|
| 79 |
+
help="Whether to checkpoint at the end of the experiment. Default is False.",
|
| 80 |
+
)
|
| 81 |
+
parser.add_argument(
|
| 82 |
+
"--keep-checkpoints-num",
|
| 83 |
+
default=None,
|
| 84 |
+
type=int,
|
| 85 |
+
help="Number of best checkpoints to keep. Others get "
|
| 86 |
+
"deleted. Default (None) keeps all checkpoints.",
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--checkpoint-score-attr",
|
| 90 |
+
default="training_iteration",
|
| 91 |
+
type=str,
|
| 92 |
+
help="Specifies by which attribute to rank the best checkpoint. "
|
| 93 |
+
"Default is increasing order. If attribute starts with min- it "
|
| 94 |
+
"will rank attribute in decreasing order. Example: "
|
| 95 |
+
"min-validation_loss",
|
| 96 |
+
)
|
| 97 |
+
parser.add_argument(
|
| 98 |
+
"--export-formats",
|
| 99 |
+
default=None,
|
| 100 |
+
help="List of formats that exported at the end of the experiment. "
|
| 101 |
+
"Default is None. For RLlib, 'checkpoint' and 'model' are "
|
| 102 |
+
"supported for TensorFlow policy graphs.",
|
| 103 |
+
)
|
| 104 |
+
parser.add_argument(
|
| 105 |
+
"--max-failures",
|
| 106 |
+
default=3,
|
| 107 |
+
type=int,
|
| 108 |
+
help="Try to recover a trial from its last checkpoint at least this "
|
| 109 |
+
"many times. Only applies if checkpointing is enabled.",
|
| 110 |
+
)
|
| 111 |
+
parser.add_argument(
|
| 112 |
+
"--scheduler",
|
| 113 |
+
default="FIFO",
|
| 114 |
+
type=str,
|
| 115 |
+
help="FIFO (default), MedianStopping, AsyncHyperBand, "
|
| 116 |
+
"HyperBand, or HyperOpt.",
|
| 117 |
+
)
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
"--scheduler-config",
|
| 120 |
+
default="{}",
|
| 121 |
+
type=json.loads,
|
| 122 |
+
help="Config options to pass to the scheduler.",
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Note: this currently only makes sense when running a single trial
|
| 126 |
+
parser.add_argument(
|
| 127 |
+
"--restore",
|
| 128 |
+
default=None,
|
| 129 |
+
type=str,
|
| 130 |
+
help="If specified, restore from this checkpoint.",
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
return parser
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _to_argv(config):
|
| 137 |
+
"""Converts configuration to a command line argument format."""
|
| 138 |
+
argv = []
|
| 139 |
+
for k, v in config.items():
|
| 140 |
+
if "-" in k:
|
| 141 |
+
raise ValueError("Use '_' instead of '-' in `{}`".format(k))
|
| 142 |
+
if v is None:
|
| 143 |
+
continue
|
| 144 |
+
if not isinstance(v, bool) or v: # for argparse flags
|
| 145 |
+
argv.append("--{}".format(k.replace("_", "-")))
|
| 146 |
+
if isinstance(v, str):
|
| 147 |
+
argv.append(v)
|
| 148 |
+
elif isinstance(v, bool):
|
| 149 |
+
pass
|
| 150 |
+
elif callable(v):
|
| 151 |
+
argv.append(json.dumps(v, cls=TuneFunctionEncoder))
|
| 152 |
+
else:
|
| 153 |
+
argv.append(json.dumps(v, cls=SafeFallbackEncoder))
|
| 154 |
+
return argv
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
_cached_pgf = {}
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _create_trial_from_spec(
|
| 161 |
+
spec: dict, parser: argparse.ArgumentParser, **trial_kwargs
|
| 162 |
+
):
|
| 163 |
+
"""Creates a Trial object from parsing the spec.
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
spec: A resolved experiment specification. Arguments should
|
| 167 |
+
The args here should correspond to the command line flags
|
| 168 |
+
in ray.tune.experiment.config_parser.
|
| 169 |
+
parser: An argument parser object from
|
| 170 |
+
make_parser.
|
| 171 |
+
trial_kwargs: Extra keyword arguments used in instantiating the Trial.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
A trial object with corresponding parameters to the specification.
|
| 175 |
+
"""
|
| 176 |
+
global _cached_pgf
|
| 177 |
+
|
| 178 |
+
spec = spec.copy()
|
| 179 |
+
resources = spec.pop("resources_per_trial", None)
|
| 180 |
+
|
| 181 |
+
try:
|
| 182 |
+
args, _ = parser.parse_known_args(_to_argv(spec))
|
| 183 |
+
except SystemExit:
|
| 184 |
+
raise TuneError("Error parsing args, see above message", spec)
|
| 185 |
+
|
| 186 |
+
if resources:
|
| 187 |
+
trial_kwargs["placement_group_factory"] = resources
|
| 188 |
+
|
| 189 |
+
checkpoint_config = spec.get("checkpoint_config", CheckpointConfig())
|
| 190 |
+
|
| 191 |
+
return Trial(
|
| 192 |
+
# Submitting trial via server in py2.7 creates Unicode, which does not
|
| 193 |
+
# convert to string in a straightforward manner.
|
| 194 |
+
trainable_name=spec["run"],
|
| 195 |
+
# json.load leads to str -> unicode in py2.7
|
| 196 |
+
config=spec.get("config", {}),
|
| 197 |
+
# json.load leads to str -> unicode in py2.7
|
| 198 |
+
stopping_criterion=spec.get("stop", {}),
|
| 199 |
+
checkpoint_config=checkpoint_config,
|
| 200 |
+
export_formats=spec.get("export_formats", []),
|
| 201 |
+
# str(None) doesn't create None
|
| 202 |
+
restore_path=spec.get("restore"),
|
| 203 |
+
trial_name_creator=spec.get("trial_name_creator"),
|
| 204 |
+
trial_dirname_creator=spec.get("trial_dirname_creator"),
|
| 205 |
+
log_to_file=spec.get("log_to_file"),
|
| 206 |
+
# str(None) doesn't create None
|
| 207 |
+
max_failures=args.max_failures,
|
| 208 |
+
storage=spec.get("storage"),
|
| 209 |
+
**trial_kwargs,
|
| 210 |
+
)
|
.venv/lib/python3.11/site-packages/ray/tune/experiment/experiment.py
ADDED
|
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import datetime
|
| 3 |
+
import logging
|
| 4 |
+
import pprint as pp
|
| 5 |
+
import traceback
|
| 6 |
+
from functools import partial
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from pickle import PicklingError
|
| 9 |
+
from typing import (
|
| 10 |
+
TYPE_CHECKING,
|
| 11 |
+
Any,
|
| 12 |
+
Callable,
|
| 13 |
+
Dict,
|
| 14 |
+
List,
|
| 15 |
+
Mapping,
|
| 16 |
+
Optional,
|
| 17 |
+
Sequence,
|
| 18 |
+
Type,
|
| 19 |
+
Union,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
import ray
|
| 23 |
+
from ray.exceptions import RpcError
|
| 24 |
+
from ray.train import CheckpointConfig, SyncConfig
|
| 25 |
+
from ray.train._internal.storage import StorageContext
|
| 26 |
+
from ray.train.constants import DEFAULT_STORAGE_PATH
|
| 27 |
+
from ray.tune.error import TuneError
|
| 28 |
+
from ray.tune.registry import is_function_trainable, register_trainable
|
| 29 |
+
from ray.tune.stopper import CombinedStopper, FunctionStopper, Stopper, TimeoutStopper
|
| 30 |
+
from ray.util.annotations import Deprecated, DeveloperAPI
|
| 31 |
+
|
| 32 |
+
if TYPE_CHECKING:
|
| 33 |
+
import pyarrow.fs
|
| 34 |
+
|
| 35 |
+
from ray.tune import PlacementGroupFactory
|
| 36 |
+
from ray.tune.experiment import Trial
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _validate_log_to_file(log_to_file):
|
| 43 |
+
"""Validate ``train.RunConfig``'s ``log_to_file`` parameter. Return
|
| 44 |
+
validated relative stdout and stderr filenames."""
|
| 45 |
+
if not log_to_file:
|
| 46 |
+
stdout_file = stderr_file = None
|
| 47 |
+
elif isinstance(log_to_file, bool) and log_to_file:
|
| 48 |
+
stdout_file = "stdout"
|
| 49 |
+
stderr_file = "stderr"
|
| 50 |
+
elif isinstance(log_to_file, str):
|
| 51 |
+
stdout_file = stderr_file = log_to_file
|
| 52 |
+
elif isinstance(log_to_file, Sequence):
|
| 53 |
+
if len(log_to_file) != 2:
|
| 54 |
+
raise ValueError(
|
| 55 |
+
"If you pass a Sequence to `log_to_file` it has to have "
|
| 56 |
+
"a length of 2 (for stdout and stderr, respectively). The "
|
| 57 |
+
"Sequence you passed has length {}.".format(len(log_to_file))
|
| 58 |
+
)
|
| 59 |
+
stdout_file, stderr_file = log_to_file
|
| 60 |
+
else:
|
| 61 |
+
raise ValueError(
|
| 62 |
+
"You can pass a boolean, a string, or a Sequence of length 2 to "
|
| 63 |
+
"`log_to_file`, but you passed something else ({}).".format(
|
| 64 |
+
type(log_to_file)
|
| 65 |
+
)
|
| 66 |
+
)
|
| 67 |
+
return stdout_file, stderr_file
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@DeveloperAPI
|
| 71 |
+
class Experiment:
|
| 72 |
+
"""Tracks experiment specifications.
|
| 73 |
+
|
| 74 |
+
Implicitly registers the Trainable if needed. The args here take
|
| 75 |
+
the same meaning as the arguments defined `tune.py:run`.
|
| 76 |
+
|
| 77 |
+
.. code-block:: python
|
| 78 |
+
|
| 79 |
+
experiment_spec = Experiment(
|
| 80 |
+
"my_experiment_name",
|
| 81 |
+
my_func,
|
| 82 |
+
stop={"mean_accuracy": 100},
|
| 83 |
+
config={
|
| 84 |
+
"alpha": tune.grid_search([0.2, 0.4, 0.6]),
|
| 85 |
+
"beta": tune.grid_search([1, 2]),
|
| 86 |
+
},
|
| 87 |
+
resources_per_trial={
|
| 88 |
+
"cpu": 1,
|
| 89 |
+
"gpu": 0
|
| 90 |
+
},
|
| 91 |
+
num_samples=10,
|
| 92 |
+
local_dir="~/ray_results",
|
| 93 |
+
checkpoint_freq=10,
|
| 94 |
+
max_failures=2)
|
| 95 |
+
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
# Keys that will be present in `public_spec` dict.
|
| 99 |
+
PUBLIC_KEYS = {"stop", "num_samples", "time_budget_s"}
|
| 100 |
+
_storage_context_cls = StorageContext
|
| 101 |
+
|
| 102 |
+
def __init__(
|
| 103 |
+
self,
|
| 104 |
+
name: str,
|
| 105 |
+
run: Union[str, Callable, Type],
|
| 106 |
+
*,
|
| 107 |
+
stop: Optional[Union[Mapping, Stopper, Callable[[str, Mapping], bool]]] = None,
|
| 108 |
+
time_budget_s: Optional[Union[int, float, datetime.timedelta]] = None,
|
| 109 |
+
config: Optional[Dict[str, Any]] = None,
|
| 110 |
+
resources_per_trial: Union[
|
| 111 |
+
None, Mapping[str, Union[float, int, Mapping]], "PlacementGroupFactory"
|
| 112 |
+
] = None,
|
| 113 |
+
num_samples: int = 1,
|
| 114 |
+
storage_path: Optional[str] = None,
|
| 115 |
+
storage_filesystem: Optional["pyarrow.fs.FileSystem"] = None,
|
| 116 |
+
sync_config: Optional[Union[SyncConfig, dict]] = None,
|
| 117 |
+
checkpoint_config: Optional[Union[CheckpointConfig, dict]] = None,
|
| 118 |
+
trial_name_creator: Optional[Callable[["Trial"], str]] = None,
|
| 119 |
+
trial_dirname_creator: Optional[Callable[["Trial"], str]] = None,
|
| 120 |
+
log_to_file: bool = False,
|
| 121 |
+
export_formats: Optional[Sequence] = None,
|
| 122 |
+
max_failures: int = 0,
|
| 123 |
+
restore: Optional[str] = None,
|
| 124 |
+
# Deprecated
|
| 125 |
+
local_dir: Optional[str] = None,
|
| 126 |
+
):
|
| 127 |
+
if isinstance(checkpoint_config, dict):
|
| 128 |
+
checkpoint_config = CheckpointConfig(**checkpoint_config)
|
| 129 |
+
else:
|
| 130 |
+
checkpoint_config = checkpoint_config or CheckpointConfig()
|
| 131 |
+
|
| 132 |
+
if is_function_trainable(run):
|
| 133 |
+
if checkpoint_config.checkpoint_at_end:
|
| 134 |
+
raise ValueError(
|
| 135 |
+
"'checkpoint_at_end' cannot be used with a function trainable. "
|
| 136 |
+
"You should include one last call to "
|
| 137 |
+
"`ray.train.report(metrics=..., checkpoint=...)` "
|
| 138 |
+
"at the end of your training loop to get this behavior."
|
| 139 |
+
)
|
| 140 |
+
if checkpoint_config.checkpoint_frequency:
|
| 141 |
+
raise ValueError(
|
| 142 |
+
"'checkpoint_frequency' cannot be set for a function trainable. "
|
| 143 |
+
"You will need to report a checkpoint every "
|
| 144 |
+
"`checkpoint_frequency` iterations within your training loop using "
|
| 145 |
+
"`ray.train.report(metrics=..., checkpoint=...)` "
|
| 146 |
+
"to get this behavior."
|
| 147 |
+
)
|
| 148 |
+
try:
|
| 149 |
+
self._run_identifier = Experiment.register_if_needed(run)
|
| 150 |
+
except RpcError as e:
|
| 151 |
+
if e.rpc_code == ray._raylet.GRPC_STATUS_CODE_RESOURCE_EXHAUSTED:
|
| 152 |
+
raise TuneError(
|
| 153 |
+
f"The Trainable/training function is too large for grpc resource "
|
| 154 |
+
f"limit. Check that its definition is not implicitly capturing a "
|
| 155 |
+
f"large array or other object in scope. "
|
| 156 |
+
f"Tip: use tune.with_parameters() to put large objects "
|
| 157 |
+
f"in the Ray object store. \n"
|
| 158 |
+
f"Original exception: {traceback.format_exc()}"
|
| 159 |
+
)
|
| 160 |
+
else:
|
| 161 |
+
raise e
|
| 162 |
+
|
| 163 |
+
if not name:
|
| 164 |
+
name = StorageContext.get_experiment_dir_name(run)
|
| 165 |
+
|
| 166 |
+
storage_path = storage_path or DEFAULT_STORAGE_PATH
|
| 167 |
+
self.storage = self._storage_context_cls(
|
| 168 |
+
storage_path=storage_path,
|
| 169 |
+
storage_filesystem=storage_filesystem,
|
| 170 |
+
sync_config=sync_config,
|
| 171 |
+
experiment_dir_name=name,
|
| 172 |
+
)
|
| 173 |
+
logger.debug(f"StorageContext on the DRIVER:\n{self.storage}")
|
| 174 |
+
|
| 175 |
+
config = config or {}
|
| 176 |
+
if not isinstance(config, dict):
|
| 177 |
+
raise ValueError(
|
| 178 |
+
f"`Experiment(config)` must be a dict, got: {type(config)}. "
|
| 179 |
+
"Please convert your search space to a dict before passing it in."
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
self._stopper = None
|
| 183 |
+
stopping_criteria = {}
|
| 184 |
+
if not stop:
|
| 185 |
+
pass
|
| 186 |
+
elif isinstance(stop, list):
|
| 187 |
+
bad_stoppers = [s for s in stop if not isinstance(s, Stopper)]
|
| 188 |
+
if bad_stoppers:
|
| 189 |
+
stopper_types = [type(s) for s in stop]
|
| 190 |
+
raise ValueError(
|
| 191 |
+
"If you pass a list as the `stop` argument to "
|
| 192 |
+
"`train.RunConfig()`, each element must be an instance of "
|
| 193 |
+
f"`tune.stopper.Stopper`. Got {stopper_types}."
|
| 194 |
+
)
|
| 195 |
+
self._stopper = CombinedStopper(*stop)
|
| 196 |
+
elif isinstance(stop, dict):
|
| 197 |
+
stopping_criteria = stop
|
| 198 |
+
elif callable(stop):
|
| 199 |
+
if FunctionStopper.is_valid_function(stop):
|
| 200 |
+
self._stopper = FunctionStopper(stop)
|
| 201 |
+
elif isinstance(stop, Stopper):
|
| 202 |
+
self._stopper = stop
|
| 203 |
+
else:
|
| 204 |
+
raise ValueError(
|
| 205 |
+
"Provided stop object must be either a dict, "
|
| 206 |
+
"a function, or a subclass of "
|
| 207 |
+
f"`ray.tune.Stopper`. Got {type(stop)}."
|
| 208 |
+
)
|
| 209 |
+
else:
|
| 210 |
+
raise ValueError(
|
| 211 |
+
f"Invalid stop criteria: {stop}. Must be a "
|
| 212 |
+
f"callable or dict. Got {type(stop)}."
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
if time_budget_s:
|
| 216 |
+
if self._stopper:
|
| 217 |
+
self._stopper = CombinedStopper(
|
| 218 |
+
self._stopper, TimeoutStopper(time_budget_s)
|
| 219 |
+
)
|
| 220 |
+
else:
|
| 221 |
+
self._stopper = TimeoutStopper(time_budget_s)
|
| 222 |
+
|
| 223 |
+
stdout_file, stderr_file = _validate_log_to_file(log_to_file)
|
| 224 |
+
|
| 225 |
+
spec = {
|
| 226 |
+
"run": self._run_identifier,
|
| 227 |
+
"stop": stopping_criteria,
|
| 228 |
+
"time_budget_s": time_budget_s,
|
| 229 |
+
"config": config,
|
| 230 |
+
"resources_per_trial": resources_per_trial,
|
| 231 |
+
"num_samples": num_samples,
|
| 232 |
+
"checkpoint_config": checkpoint_config,
|
| 233 |
+
"trial_name_creator": trial_name_creator,
|
| 234 |
+
"trial_dirname_creator": trial_dirname_creator,
|
| 235 |
+
"log_to_file": (stdout_file, stderr_file),
|
| 236 |
+
"export_formats": export_formats or [],
|
| 237 |
+
"max_failures": max_failures,
|
| 238 |
+
"restore": (
|
| 239 |
+
Path(restore).expanduser().absolute().as_posix() if restore else None
|
| 240 |
+
),
|
| 241 |
+
"storage": self.storage,
|
| 242 |
+
}
|
| 243 |
+
self.spec = spec
|
| 244 |
+
|
| 245 |
+
@classmethod
|
| 246 |
+
def from_json(cls, name: str, spec: dict):
|
| 247 |
+
"""Generates an Experiment object from JSON.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
name: Name of Experiment.
|
| 251 |
+
spec: JSON configuration of experiment.
|
| 252 |
+
"""
|
| 253 |
+
if "run" not in spec:
|
| 254 |
+
raise TuneError("No trainable specified!")
|
| 255 |
+
|
| 256 |
+
# Special case the `env` param for RLlib by automatically
|
| 257 |
+
# moving it into the `config` section.
|
| 258 |
+
if "env" in spec:
|
| 259 |
+
spec["config"] = spec.get("config", {})
|
| 260 |
+
spec["config"]["env"] = spec["env"]
|
| 261 |
+
del spec["env"]
|
| 262 |
+
|
| 263 |
+
if "sync_config" in spec and isinstance(spec["sync_config"], dict):
|
| 264 |
+
spec["sync_config"] = SyncConfig(**spec["sync_config"])
|
| 265 |
+
|
| 266 |
+
if "checkpoint_config" in spec and isinstance(spec["checkpoint_config"], dict):
|
| 267 |
+
spec["checkpoint_config"] = CheckpointConfig(**spec["checkpoint_config"])
|
| 268 |
+
|
| 269 |
+
spec = copy.deepcopy(spec)
|
| 270 |
+
|
| 271 |
+
run_value = spec.pop("run")
|
| 272 |
+
try:
|
| 273 |
+
exp = cls(name, run_value, **spec)
|
| 274 |
+
except TypeError as e:
|
| 275 |
+
raise TuneError(
|
| 276 |
+
f"Failed to load the following Tune experiment "
|
| 277 |
+
f"specification:\n\n {pp.pformat(spec)}.\n\n"
|
| 278 |
+
f"Please check that the arguments are valid. "
|
| 279 |
+
f"Experiment creation failed with the following "
|
| 280 |
+
f"error:\n {e}"
|
| 281 |
+
)
|
| 282 |
+
return exp
|
| 283 |
+
|
| 284 |
+
@classmethod
def get_trainable_name(cls, run_object: Union[str, Callable, Type]):
    """Get Trainable name.

    Args:
        run_object: Trainable to run. If string,
            assumes it is an ID and does not modify it. Otherwise,
            returns a string corresponding to the run_object name.

    Returns:
        A string representing the trainable identifier.

    Raises:
        TuneError: if ``run_object`` passed in is invalid.
    """
    from ray.tune.search.sample import Domain

    # Strings (registered IDs) and search-space domains pass through as-is.
    if isinstance(run_object, (str, Domain)):
        return run_object

    # Anything else must be a class or a callable to be runnable.
    if not (isinstance(run_object, type) or callable(run_object)):
        raise TuneError("Improper 'run' - not string nor trainable.")

    fallback = "DEFAULT"
    if hasattr(run_object, "_name"):
        # Trainable classes may carry an explicit name.
        return run_object._name
    if hasattr(run_object, "__name__"):
        dunder_name = run_object.__name__
        if dunder_name == "<lambda>":
            return "lambda"
        if dunder_name.startswith("<"):
            # Other synthetic names (e.g. "<locals>") are not usable.
            return fallback
        return dunder_name
    if (
        isinstance(run_object, partial)
        and hasattr(run_object, "func")
        and hasattr(run_object.func, "__name__")
    ):
        # functools.partial has no __name__ itself; use the wrapped function's.
        return run_object.func.__name__
    logger.warning("No name detected on trainable. Using {}.".format(fallback))
    return fallback
|
| 326 |
+
|
| 327 |
+
@classmethod
def register_if_needed(cls, run_object: Union[str, Callable, Type]):
    """Registers Trainable or Function at runtime.

    Assumes already registered if run_object is a string.
    Also, does not inspect interface of given run_object.

    Args:
        run_object: Trainable to run. If string,
            assumes it is an ID and does not modify it. Otherwise,
            returns a string corresponding to the run_object name.

    Returns:
        A string representing the trainable identifier.
    """
    from ray.tune.search.sample import Domain

    # A string is assumed to already be a registered trainable ID.
    if isinstance(run_object, str):
        return run_object
    # Search-space domains are resolved as variants later, not registered.
    if isinstance(run_object, Domain):
        logger.warning("Not registering trainable. Resolving as variant.")
        return run_object

    trainable_id = cls.get_trainable_name(run_object)
    try:
        register_trainable(trainable_id, run_object)
    except (TypeError, PicklingError) as e:
        # Registration serializes the trainable; surface pickling hints.
        extra_msg = (
            "Other options: "
            "\n-Try reproducing the issue by calling "
            "`pickle.dumps(trainable)`. "
            "\n-If the error is typing-related, try removing "
            "the type annotations and try again."
        )
        raise type(e)(str(e) + " " + extra_msg) from None
    return trainable_id
|
| 362 |
+
|
| 363 |
+
@property
def stopper(self):
    """The configured stopper object (``self._stopper``)."""
    return self._stopper
|
| 366 |
+
|
| 367 |
+
@property
def local_path(self) -> Optional[str]:
    """Local driver staging path for this experiment (from the storage context)."""
    return self.storage.experiment_driver_staging_path
|
| 370 |
+
|
| 371 |
+
@property
@Deprecated("Replaced by `local_path`")
def local_dir(self):
    """Deprecated accessor; use ``local_path`` instead."""
    # TODO(justinvyu): [Deprecated] Remove in 2.11.
    # NOTE: deliberately raises (hard removal) rather than emitting a warning.
    raise DeprecationWarning("Use `local_path` instead of `local_dir`.")
|
| 376 |
+
|
| 377 |
+
@property
def remote_path(self) -> Optional[str]:
    """Experiment path on the storage filesystem (may be remote, e.g. cloud)."""
    return self.storage.experiment_fs_path
|
| 380 |
+
|
| 381 |
+
@property
def path(self) -> Optional[str]:
    """Canonical experiment path: the storage path when set, else the local one."""
    remote = self.remote_path
    if remote:
        return remote
    return self.local_path
|
| 384 |
+
|
| 385 |
+
@property
def checkpoint_config(self):
    """Checkpoint configuration stored in the experiment spec (None if absent)."""
    return self.spec.get("checkpoint_config")
|
| 388 |
+
|
| 389 |
+
@property
@Deprecated("Replaced by `local_path`")
def checkpoint_dir(self):
    """Deprecated accessor; use ``local_path`` instead."""
    # TODO(justinvyu): [Deprecated] Remove in 2.11.
    # NOTE: deliberately raises (hard removal) rather than emitting a warning.
    raise DeprecationWarning("Use `local_path` instead of `checkpoint_dir`.")
|
| 394 |
+
|
| 395 |
+
@property
def run_identifier(self):
    """Returns a string representing the trainable identifier."""
    # Backed by self._run_identifier — presumably set at construction; confirm
    # in __init__ (not visible in this chunk).
    return self._run_identifier
|
| 399 |
+
|
| 400 |
+
@property
def public_spec(self) -> Dict[str, Any]:
    """Returns the spec dict with only the public-facing keys.

    Intended to be used for passing information to callbacks,
    Searchers and Schedulers.
    """
    allowed = self.PUBLIC_KEYS
    return {key: value for key, value in self.spec.items() if key in allowed}
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def _convert_to_experiment_list(experiments: Union[Experiment, List[Experiment], Dict]):
    """Produces a list of Experiment objects.

    Converts input from dict, single experiment, or list of
    experiments to list of experiments. If input is None,
    will return an empty list.

    Arguments:
        experiments: Experiments to run.

    Returns:
        List of experiments.

    Raises:
        TuneError: if the converted value is not a list of ``Experiment``
            objects.
    """
    exp_list = experiments

    # Transform list if necessary
    if experiments is None:
        exp_list = []
    elif isinstance(experiments, Experiment):
        exp_list = [experiments]
    elif isinstance(experiments, dict):
        # isinstance (rather than an exact `type(...) is dict` check) so dict
        # subclasses such as OrderedDict are converted instead of rejected.
        exp_list = [
            Experiment.from_json(name, spec) for name, spec in experiments.items()
        ]

    # Validate exp_list
    if isinstance(exp_list, list) and all(
        isinstance(exp, Experiment) for exp in exp_list
    ):
        if len(exp_list) > 1:
            logger.info(
                "Running with multiple concurrent experiments. "
                "All experiments will be using the same SearchAlgorithm."
            )
    else:
        raise TuneError("Invalid argument: {}".format(experiments))

    return exp_list
|
.venv/lib/python3.11/site-packages/ray/tune/experiment/trial.py
ADDED
|
@@ -0,0 +1,1073 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import platform
|
| 6 |
+
import re
|
| 7 |
+
import time
|
| 8 |
+
import uuid
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
from functools import partial
|
| 11 |
+
from numbers import Number
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
| 14 |
+
|
| 15 |
+
import ray
|
| 16 |
+
import ray.cloudpickle as cloudpickle
|
| 17 |
+
from ray._private.utils import binary_to_hex, hex_to_binary
|
| 18 |
+
from ray.air.constants import (
|
| 19 |
+
EXPR_ERROR_FILE,
|
| 20 |
+
EXPR_ERROR_PICKLE_FILE,
|
| 21 |
+
TRAINING_ITERATION,
|
| 22 |
+
)
|
| 23 |
+
from ray.exceptions import RayActorError, RayTaskError
|
| 24 |
+
from ray.train import Checkpoint, CheckpointConfig
|
| 25 |
+
from ray.train._internal.checkpoint_manager import _CheckpointManager
|
| 26 |
+
from ray.train._internal.session import _FutureTrainingResult, _TrainingResult
|
| 27 |
+
from ray.train._internal.storage import StorageContext, _exists_at_fs_path
|
| 28 |
+
from ray.train.constants import (
|
| 29 |
+
RAY_CHDIR_TO_TRIAL_DIR,
|
| 30 |
+
RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE,
|
| 31 |
+
)
|
| 32 |
+
from ray.tune.error import TuneError
|
| 33 |
+
from ray.tune.execution.placement_groups import (
|
| 34 |
+
PlacementGroupFactory,
|
| 35 |
+
resource_dict_to_pg_factory,
|
| 36 |
+
)
|
| 37 |
+
from ray.tune.logger import NoopLogger
|
| 38 |
+
|
| 39 |
+
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
|
| 40 |
+
# need because there are cyclic imports that may cause specific names to not
|
| 41 |
+
# have been defined yet. See https://github.com/ray-project/ray/issues/1716.
|
| 42 |
+
from ray.tune.registry import get_trainable_cls, validate_trainable
|
| 43 |
+
from ray.tune.result import (
|
| 44 |
+
DEBUG_METRICS,
|
| 45 |
+
DONE,
|
| 46 |
+
NODE_IP,
|
| 47 |
+
PID,
|
| 48 |
+
STDERR_FILE,
|
| 49 |
+
STDOUT_FILE,
|
| 50 |
+
TRIAL_ID,
|
| 51 |
+
TRIAL_INFO,
|
| 52 |
+
)
|
| 53 |
+
from ray.tune.trainable.metadata import _TrainingRunMetadata
|
| 54 |
+
from ray.tune.utils import date_str, flatten_dict
|
| 55 |
+
from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder
|
| 56 |
+
from ray.util import log_once
|
| 57 |
+
from ray.util.annotations import Deprecated, DeveloperAPI
|
| 58 |
+
|
| 59 |
+
DEBUG_PRINT_INTERVAL = 5
|
| 60 |
+
_DEFAULT_WIN_MAX_PATH_LENGTH = 260
|
| 61 |
+
TRIAL_STATE_FILENAME = "trial_metadata.json"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
logger = logging.getLogger(__name__)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class _Location:
    """Describes the location at which Trial is placed to run."""

    def __init__(self, hostname=None, pid=None):
        self.hostname = hostname
        self.pid = pid

    def __str__(self):
        # No pid recorded yet: the trial has not been placed anywhere.
        if not self.pid:
            return ""
        # On the local machine the hostname is redundant; show only the pid.
        if self.hostname == platform.node():
            return f"pid={self.pid}"
        return f"{self.hostname}:{self.pid}"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@DeveloperAPI
|
| 84 |
+
class ExportFormat:
    """Describes the format to import/export the trial Trainable.

    This may correspond to different file formats based on the
    Trainable implementation.
    """

    CHECKPOINT = "checkpoint"
    MODEL = "model"
    ONNX = "onnx"
    H5 = "h5"

    @staticmethod
    def validate(formats):
        """Normalizes ``formats`` in place and validates each entry.

        Each element is stripped and lower-cased before being checked.

        Raises:
            TuneError: if the format is unknown.
        """
        known = (
            ExportFormat.CHECKPOINT,
            ExportFormat.MODEL,
            ExportFormat.ONNX,
            ExportFormat.H5,
        )
        for index, raw in enumerate(formats):
            normalized = raw.strip().lower()
            formats[index] = normalized
            if normalized not in known:
                raise TuneError("Unsupported import/export format: " + normalized)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class _TrialInfo:
    """Serializable struct for holding information for a Trial.

    Attributes:
        trial_name: String name of the current trial.
        trial_id: trial_id of the trial
        trial_resources: resources used by trial.
    """

    def __init__(self, trial: "Trial"):
        # Snapshot plain values off the Trial so this struct can be
        # serialized without keeping a reference to the Trial itself.
        self._trial_name = str(trial)
        self._trial_id = trial.trial_id
        self._trial_resources = trial.placement_group_factory
        self._experiment_name = trial.experiment_dir_name

    @property
    def experiment_name(self):
        return self._experiment_name

    @property
    def trial_name(self):
        return self._trial_name

    @property
    def trial_id(self):
        return self._trial_id

    @property
    def trial_resources(self) -> PlacementGroupFactory:
        return self._trial_resources

    @trial_resources.setter
    def trial_resources(self, new_resources: PlacementGroupFactory):
        # The only mutable field: resources can be replaced after creation
        # (presumably by resource-updating schedulers — confirm with callers).
        self._trial_resources = new_resources
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class _TemporaryTrialState:
    """Temporary trial state.

    Values saved here should not be restored on resume.
    """

    def __init__(self):
        # Current placement (hostname/pid); empty until the trial is placed.
        self.location = _Location()

        # Handle to the remote trainable actor, if one is alive.
        self.ray_actor: Optional[ray.actor.ActorHandle] = None

        # In-flight save / restore operations, if any.
        self.saving_to: Optional[_FutureTrainingResult] = None
        self.restoring_from: Optional[_TrainingResult] = None

        self.num_restore_failures: int = 0

    def __getstate__(self):
        # Intentionally drop everything on pickling: this state is transient
        # and must not survive checkpoint/restore.
        return {}
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _get_max_path_length() -> int:
    """Return the maximum filesystem path length for this platform."""
    if not hasattr(os, "pathconf"):
        # Windows: os.pathconf is unavailable; fall back to classic MAX_PATH.
        return _DEFAULT_WIN_MAX_PATH_LENGTH
    return os.pathconf("/", "PC_PATH_MAX")
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _create_unique_logdir_name(root: str, relative_logdir: str) -> str:
    """Return ``relative_logdir``, suffixed with a short random hex string if
    that directory already exists under ``root``."""
    target = Path(root).expanduser() / relative_logdir
    if not target.exists():
        return relative_logdir
    unique = relative_logdir + "_" + uuid.uuid4().hex[:4]
    logger.info(
        f"Creating a new dirname {unique} because "
        f"trial dirname '{relative_logdir}' already exists."
    )
    return unique
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _noop_logger_creator(config: Dict[str, Any], logdir: str):
    """Create a ``NoopLogger`` for a remote trainable, preparing its logdir.

    Side effects: records the actor's original working directory in
    ``TUNE_ORIG_WORKING_DIR``, creates ``logdir``, and (unless disabled via
    the env var, or running in local mode) chdirs into it.
    """
    # Upon remote process setup, record the actor's original working dir before
    # changing to the Tune logdir
    os.environ.setdefault("TUNE_ORIG_WORKING_DIR", os.getcwd())

    os.makedirs(logdir, exist_ok=True)

    should_chdir = bool(int(os.environ.get(RAY_CHDIR_TO_TRIAL_DIR, "1")))
    # Set the working dir to the trial directory in the remote process,
    # for user file writes (skipped in Ray local mode).
    if should_chdir and ray._private.worker._mode() != ray._private.worker.LOCAL_MODE:
        os.chdir(logdir)

    return NoopLogger(config, logdir)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _get_trainable_kwargs(trial: "Trial") -> Dict[str, Any]:
    """Build the constructor kwargs for a trial's Trainable."""
    trial.init_local_path()

    logger_creator = partial(
        _noop_logger_creator, logdir=trial.storage.trial_working_directory
    )

    # Deep-copy so the trainable's view of the config cannot mutate the trial's.
    run_config = copy.deepcopy(trial.config)
    run_config[TRIAL_INFO] = _TrialInfo(trial)

    stdout_file, stderr_file = trial.log_to_file
    run_config[STDOUT_FILE] = stdout_file
    run_config[STDERR_FILE] = stderr_file

    assert trial.storage.trial_dir_name

    return {
        "config": run_config,
        "logger_creator": logger_creator,
        "storage": trial.storage,
    }
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
@contextmanager
def _change_working_directory(trial):
    """Context manager changing working directory to trial logdir.
    Used in local mode.

    For non-local mode it is no-op.
    """
    if ray._private.worker._mode() != ray._private.worker.LOCAL_MODE:
        yield
        return
    previous_dir = os.getcwd()
    os.chdir(trial.local_path)
    try:
        yield
    finally:
        # Always restore the original working directory.
        os.chdir(previous_dir)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
@DeveloperAPI
|
| 248 |
+
class Trial:
|
| 249 |
+
"""A trial object holds the state for one model training run.
|
| 250 |
+
|
| 251 |
+
Trials are themselves managed by the TrialRunner class, which implements
|
| 252 |
+
the event loop for submitting trial runs to a Ray cluster.
|
| 253 |
+
|
| 254 |
+
Trials start in the PENDING state, and transition to RUNNING once started.
|
| 255 |
+
On error, it transitions to ERROR, otherwise TERMINATED on success.
|
| 256 |
+
|
| 257 |
+
There are resources allocated to each trial. These should be specified
|
| 258 |
+
using ``PlacementGroupFactory``.
|
| 259 |
+
|
| 260 |
+
Attributes:
|
| 261 |
+
trainable_name: Name of the trainable object to be executed.
|
| 262 |
+
config: Provided configuration dictionary with evaluated params.
|
| 263 |
+
trial_id: Unique identifier for the trial.
|
| 264 |
+
path: Path where results for this trial are stored. Can be on
|
| 265 |
+
the local node or on cloud storage.
|
| 266 |
+
local_path: Path on the local disk where results are stored.
|
| 267 |
+
remote_path: Path on cloud storage where results are stored,
|
| 268 |
+
or None if not set.
|
| 269 |
+
relative_logdir: Directory of the trial relative to its
|
| 270 |
+
experiment directory.
|
| 271 |
+
evaluated_params: Evaluated parameters by search algorithm,
|
| 272 |
+
experiment_tag: Identifying trial name to show in the console
|
| 273 |
+
status: One of PENDING, RUNNING, PAUSED, TERMINATED, ERROR/
|
| 274 |
+
error_file: Path to the errors that this trial has raised.
|
| 275 |
+
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
_nonjson_fields = [
|
| 279 |
+
"results",
|
| 280 |
+
"extra_arg",
|
| 281 |
+
"placement_group_factory",
|
| 282 |
+
"_resources",
|
| 283 |
+
"_default_placement_group_factory",
|
| 284 |
+
]
|
| 285 |
+
|
| 286 |
+
PENDING = "PENDING"
|
| 287 |
+
RUNNING = "RUNNING"
|
| 288 |
+
PAUSED = "PAUSED"
|
| 289 |
+
TERMINATED = "TERMINATED"
|
| 290 |
+
ERROR = "ERROR"
|
| 291 |
+
|
| 292 |
+
def __init__(
    self,
    trainable_name: str,
    *,
    config: Optional[Dict] = None,
    trial_id: Optional[str] = None,
    storage: Optional[StorageContext] = None,
    evaluated_params: Optional[Dict] = None,
    experiment_tag: str = "",
    placement_group_factory: Optional[PlacementGroupFactory] = None,
    stopping_criterion: Optional[Dict[str, float]] = None,
    checkpoint_config: Optional[CheckpointConfig] = None,
    export_formats: Optional[List[str]] = None,
    restore_path: Optional[str] = None,
    trial_name_creator: Optional[Callable[["Trial"], str]] = None,
    trial_dirname_creator: Optional[Callable[["Trial"], str]] = None,
    log_to_file: Union[Optional[str], Tuple[Optional[str], Optional[str]]] = None,
    max_failures: int = 0,
    stub: bool = False,
    _setup_default_resource: bool = True,
):
    """Initialize a new trial.

    The args here take the same meaning as the command line flags defined
    in ray.tune.experiment.config_parser.

    Args:
        trainable_name: Name of the registered trainable to run.
        config: Configuration dictionary for this trial.
        trial_id: Unique ID; generated via ``Trial.generate_id()`` if None.
        storage: Storage context (shallow-copied; trial dirname is written
            back into the copy by ``init_local_path``).
        checkpoint_config: Checkpointing options; a default
            ``CheckpointConfig`` is used if None.
        restore_path: If given, the trial starts by restoring from this
            checkpoint directory.
        max_failures: Number of times to retry the trial after failure.
        stub: If True, skip trainable validation (see comment below).
        _setup_default_resource: Whether to set up default resources.
            When initializing trials from checkpoints, this field is set to false,
            so that setting up default resources can be delayed till after
            ``trial.config`` is loaded from checkpoints.
    """
    # If this is set, trainables are not validated or looked up.
    # This can be used e.g. to initialize Trial objects from checkpoints
    # without loading the trainable first.
    self.stub = stub

    if not self.stub:
        validate_trainable(trainable_name)
    # Trial config
    self.trainable_name = trainable_name
    self.trial_id = Trial.generate_id() if trial_id is None else trial_id

    # Transient (non-restorable) state and persistent run metadata.
    self.temporary_state = _TemporaryTrialState()
    self.run_metadata = _TrainingRunMetadata()

    # Create a copy, since `init_local_path` updates the context with the
    # generated trial dirname.
    self.storage = copy.copy(storage)

    self.config = config or {}
    # Save a copy of the original unresolved config so that we can swap
    # out and update any reference config values after restoration.
    self.__unresolved_config = self.config

    # Parameters that Tune varies across searches.
    self.evaluated_params = evaluated_params or {}
    self.experiment_tag = experiment_tag
    self.stopping_criterion = stopping_criterion or {}

    self._setup_default_resource = _setup_default_resource

    # Normalize plain resource dicts into a PlacementGroupFactory.
    if placement_group_factory and not isinstance(
        placement_group_factory, PlacementGroupFactory
    ):
        placement_group_factory = resource_dict_to_pg_factory(
            placement_group_factory
        )

    self._default_placement_group_factory = placement_group_factory
    # Will be created in create_placement_group_factory().
    self.placement_group_factory = None

    self.log_to_file = log_to_file
    # Make sure `stdout_file, stderr_file = Trial.log_to_file` works
    if (
        not self.log_to_file
        or not isinstance(self.log_to_file, Sequence)
        or not len(self.log_to_file) == 2
    ):
        self.log_to_file = (None, None)

    self.max_failures = max_failures

    # Local trial state that is updated during the run
    self._default_result_or_future: Union[ray.ObjectRef, dict, None] = None

    self.export_formats = export_formats
    self.status = Trial.PENDING
    self.relative_logdir = None

    self.trial_name_creator = trial_name_creator
    self.trial_dirname_creator = trial_dirname_creator
    self.custom_trial_name = None
    self.custom_dirname = None

    # Checkpoint config
    checkpoint_config = checkpoint_config or CheckpointConfig()

    self.run_metadata.checkpoint_manager = _CheckpointManager(
        checkpoint_config=checkpoint_config
    )

    # Restoration fields
    self.restore_path = restore_path
    self._restore_checkpoint_result: Optional[_TrainingResult] = None
    if restore_path:
        # tune.run(restore) passes in a path without metrics.
        self._restore_checkpoint_result = _TrainingResult(
            checkpoint=Checkpoint.from_directory(restore_path), metrics={}
        )

    # The creators receive the (mostly-initialized) trial itself, so they
    # must run after the fields above are set.
    if trial_name_creator:
        self.custom_trial_name = trial_name_creator(self)

    if trial_dirname_creator:
        self.custom_dirname = trial_dirname_creator(self)
        if os.path.sep in self.custom_dirname:
            raise ValueError(
                f"Trial dirname must not contain '/'. Got {self.custom_dirname}"
            )

    self._state_json = None
|
| 415 |
+
|
| 416 |
+
def create_placement_group_factory(self):
    """Compute placement group factory if needed.

    Note: this must be called after all the placeholders in
    self.config are resolved.
    """
    trainable_cls = self.get_trainable_cls()
    if not trainable_cls or not self._setup_default_resource:
        # No trainable to ask, or default setup is deliberately delayed:
        # fall back to the user-provided factory, or cpu=1.
        self.placement_group_factory = (
            self._default_placement_group_factory or resource_dict_to_pg_factory()
        )
        return

    default_resources = trainable_cls.default_resource_request(self.config)

    if default_resources:
        # If Trainable returns resources, do not allow manual override via
        # `resources_per_trial` by the user.
        if self._default_placement_group_factory:
            raise TuneError(
                "Resources for {} have been automatically set to {} "
                "by its `default_resource_request()` method. Please "
                "clear the `resources_per_trial` option.".format(
                    trainable_cls, default_resources
                )
            )
        if not isinstance(default_resources, PlacementGroupFactory):
            default_resources = resource_dict_to_pg_factory(default_resources)

    # Priority: trainable's default_resource_request, then
    # resources_per_trial, then cpu=1.
    self.placement_group_factory = (
        default_resources
        or self._default_placement_group_factory
        or resource_dict_to_pg_factory()
    )
|
| 456 |
+
|
| 457 |
+
def _get_default_result_or_future(self) -> Optional[dict]:
    """Calls ray.get on self._default_result_or_future and assigns back.

    Returns None in case of exceptions.
    Will also set the trial location if runner is set.
    """
    # Resolve the future (if still one) into a plain dict, caching it.
    if self._default_result_or_future and isinstance(
        self._default_result_or_future, ray.ObjectRef
    ):
        try:
            self._default_result_or_future = ray.get(self._default_result_or_future)
        except RayActorError:  # error during initialization
            self._default_result_or_future = None
    # The default result carries the actor's node IP / PID; record them as
    # the trial's current location while the actor is alive.
    if self._default_result_or_future and self.temporary_state.ray_actor:
        self.set_location(
            _Location(
                self._default_result_or_future.get(NODE_IP),
                self._default_result_or_future.get(PID),
            )
        )
    return self._default_result_or_future
|
| 478 |
+
|
| 479 |
+
def resolve_config_placeholders(self, placeholder_resolvers: Dict[Tuple, Any]):
    """Re-resolve placeholder values in the trial config.

    Starts from the pristine unresolved config so that resolution is
    repeatable (e.g. after restoration).
    """
    from ray.tune.impl.placeholder import resolve_placeholders

    # Make a copy of the unresolved config before resolve it.
    self.config = copy.deepcopy(self.__unresolved_config)
    resolve_placeholders(self.config, placeholder_resolvers)
|
| 485 |
+
|
| 486 |
+
@property
def last_result(self) -> dict:
    """The most recent result dict for this trial.

    Resolution order:
    1. If the trial has reported at least once, return that result.
    2. Otherwise, fall back to the default results dict (obtained through
       Trainable.get_auto_filled_metrics), resolving its future if needed.
    3. In the worst case, return a dict containing only the trial ID.
    """
    result = self.run_metadata.last_result
    # "Empty" means nothing besides the trial ID has been reported yet.
    has_reported_metrics = any(key != TRIAL_ID for key in result)
    if not has_reported_metrics:
        self._get_default_result_or_future()
        result = self._default_result_or_future or result
    result.setdefault(TRIAL_ID, self.trial_id)
    return result
|
| 503 |
+
|
| 504 |
+
@property
|
| 505 |
+
def metric_analysis(self):
|
| 506 |
+
return self.run_metadata.metric_analysis
|
| 507 |
+
|
| 508 |
+
@property
|
| 509 |
+
def metric_n_steps(self):
|
| 510 |
+
return self.run_metadata.metric_n_steps
|
| 511 |
+
|
| 512 |
+
def get_ray_actor_ip(self) -> Optional[str]:
|
| 513 |
+
if self.temporary_state.location.hostname:
|
| 514 |
+
return self.temporary_state.location.hostname
|
| 515 |
+
|
| 516 |
+
if not self.temporary_state.ray_actor:
|
| 517 |
+
return None
|
| 518 |
+
|
| 519 |
+
hostname, pid = ray.get(
|
| 520 |
+
self.temporary_state.ray_actor.get_current_ip_pid.remote()
|
| 521 |
+
)
|
| 522 |
+
self.temporary_state.location = _Location(hostname, pid)
|
| 523 |
+
return self.temporary_state.location.hostname
|
| 524 |
+
|
| 525 |
+
@property
|
| 526 |
+
@Deprecated("Replaced by `local_experiment_path`")
|
| 527 |
+
def local_dir(self):
|
| 528 |
+
return self.local_experiment_path
|
| 529 |
+
|
| 530 |
+
@property
|
| 531 |
+
def experiment_dir_name(self):
|
| 532 |
+
return self.storage.experiment_dir_name
|
| 533 |
+
|
| 534 |
+
@property
|
| 535 |
+
def remote_experiment_path(self) -> str:
|
| 536 |
+
return self.storage.experiment_fs_path
|
| 537 |
+
|
| 538 |
+
@property
|
| 539 |
+
def local_experiment_path(self) -> str:
|
| 540 |
+
return self.storage.experiment_driver_staging_path
|
| 541 |
+
|
| 542 |
+
@property
|
| 543 |
+
@Deprecated("Replaced by `local_path`")
|
| 544 |
+
def logdir(self) -> Optional[str]:
|
| 545 |
+
# TODO(justinvyu): [Deprecated] Remove in 2.11.
|
| 546 |
+
raise DeprecationWarning("Use `local_path` instead of `logdir`.")
|
| 547 |
+
|
| 548 |
+
@property
|
| 549 |
+
def local_path(self) -> Optional[str]:
|
| 550 |
+
return self.storage.trial_driver_staging_path
|
| 551 |
+
|
| 552 |
+
@property
|
| 553 |
+
def path(self) -> Optional[str]:
|
| 554 |
+
return self.storage.trial_fs_path
|
| 555 |
+
|
| 556 |
+
@property
|
| 557 |
+
def has_reported_at_least_once(self) -> bool:
|
| 558 |
+
return bool(self.run_metadata.last_result)
|
| 559 |
+
|
| 560 |
+
@property
|
| 561 |
+
def node_ip(self):
|
| 562 |
+
return self.temporary_state.location.hostname
|
| 563 |
+
|
| 564 |
+
@property
|
| 565 |
+
def checkpoint_at_end(self):
|
| 566 |
+
config = self.run_metadata.checkpoint_manager.checkpoint_config
|
| 567 |
+
return config.checkpoint_at_end
|
| 568 |
+
|
| 569 |
+
@property
|
| 570 |
+
def checkpoint_freq(self):
|
| 571 |
+
config = self.run_metadata.checkpoint_manager.checkpoint_config
|
| 572 |
+
return config.checkpoint_frequency
|
| 573 |
+
|
| 574 |
+
@property
|
| 575 |
+
def latest_checkpoint_result(self) -> Optional[_TrainingResult]:
|
| 576 |
+
# NOTE: Fallback to the checkpoint passed in from `tune.run(restore)`
|
| 577 |
+
# if the trial hasn't saved any checkpoints itself yet.
|
| 578 |
+
return (
|
| 579 |
+
self.run_metadata.checkpoint_manager.latest_checkpoint_result
|
| 580 |
+
or self._restore_checkpoint_result
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
@property
|
| 584 |
+
def checkpoint(self) -> Optional[Checkpoint]:
|
| 585 |
+
"""Returns the most recent checkpoint if one has been saved."""
|
| 586 |
+
return (
|
| 587 |
+
self.latest_checkpoint_result.checkpoint
|
| 588 |
+
if self.latest_checkpoint_result
|
| 589 |
+
else None
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
@classmethod
|
| 593 |
+
def generate_id(cls):
|
| 594 |
+
return str(uuid.uuid4().hex)[:8]
|
| 595 |
+
|
| 596 |
+
def reset(self) -> "Trial":
|
| 597 |
+
# If there is `default_resource_request` associated with the trainable,
|
| 598 |
+
# clear `resources` and `placement_group_factory`.
|
| 599 |
+
# This is mainly relevant for RLlib tuning jobs, where we save users
|
| 600 |
+
# of the trouble to specify the resources themselves by having some
|
| 601 |
+
# default resources for popular RLlib algorithms.
|
| 602 |
+
trainable_cls = self.get_trainable_cls()
|
| 603 |
+
clear_resources = trainable_cls and trainable_cls.default_resource_request(
|
| 604 |
+
self.config
|
| 605 |
+
)
|
| 606 |
+
placement_group_factory = (
|
| 607 |
+
self.placement_group_factory if not clear_resources else None
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
checkpoint_config = self.run_metadata.checkpoint_manager.checkpoint_config
|
| 611 |
+
return Trial(
|
| 612 |
+
self.trainable_name,
|
| 613 |
+
config=self.config,
|
| 614 |
+
trial_id=None,
|
| 615 |
+
evaluated_params=self.evaluated_params,
|
| 616 |
+
experiment_tag=self.experiment_tag,
|
| 617 |
+
placement_group_factory=placement_group_factory,
|
| 618 |
+
stopping_criterion=self.stopping_criterion,
|
| 619 |
+
checkpoint_config=checkpoint_config,
|
| 620 |
+
export_formats=self.export_formats,
|
| 621 |
+
restore_path=self.restore_path,
|
| 622 |
+
trial_name_creator=self.trial_name_creator,
|
| 623 |
+
trial_dirname_creator=self.trial_dirname_creator,
|
| 624 |
+
log_to_file=self.log_to_file,
|
| 625 |
+
max_failures=self.max_failures,
|
| 626 |
+
storage=self.storage,
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
@Deprecated("Replaced by `init_local_path()`")
|
| 630 |
+
def init_logdir(self):
|
| 631 |
+
# TODO(justinvyu): [Deprecated] Remove in 2.11.
|
| 632 |
+
raise DeprecationWarning("Use `init_local_path` instead of `init_logdir`.")
|
| 633 |
+
|
| 634 |
+
def init_local_path(self):
|
| 635 |
+
"""Init logdir."""
|
| 636 |
+
if not self.relative_logdir:
|
| 637 |
+
self.relative_logdir = _create_unique_logdir_name(
|
| 638 |
+
str(self.local_experiment_path), self._generate_dirname()
|
| 639 |
+
)
|
| 640 |
+
# Populate the storage context with the trial dir name we just generated.
|
| 641 |
+
self.storage.trial_dir_name = self.relative_logdir
|
| 642 |
+
|
| 643 |
+
assert self.local_path
|
| 644 |
+
logdir_path = Path(self.local_path)
|
| 645 |
+
max_path_length = _get_max_path_length()
|
| 646 |
+
if len(str(logdir_path)) >= max_path_length:
|
| 647 |
+
logger.warning(
|
| 648 |
+
f"The path to the trial log directory is too long "
|
| 649 |
+
f"(max length: {max_path_length}. "
|
| 650 |
+
f"Consider using `trial_dirname_creator` to shorten the path. "
|
| 651 |
+
f"Path: {logdir_path}"
|
| 652 |
+
)
|
| 653 |
+
logdir_path.mkdir(parents=True, exist_ok=True)
|
| 654 |
+
|
| 655 |
+
self.invalidate_json_state()
|
| 656 |
+
|
| 657 |
+
def update_resources(self, resources: Union[dict, PlacementGroupFactory]):
|
| 658 |
+
"""EXPERIMENTAL: Updates the resource requirements.
|
| 659 |
+
|
| 660 |
+
Should only be called when the trial is not running.
|
| 661 |
+
|
| 662 |
+
Raises:
|
| 663 |
+
ValueError: if trial status is running.
|
| 664 |
+
"""
|
| 665 |
+
if self.status is Trial.RUNNING:
|
| 666 |
+
raise ValueError("Cannot update resources while Trial is running.")
|
| 667 |
+
|
| 668 |
+
placement_group_factory = resources
|
| 669 |
+
if isinstance(resources, dict):
|
| 670 |
+
placement_group_factory = resource_dict_to_pg_factory(resources)
|
| 671 |
+
|
| 672 |
+
self.placement_group_factory = placement_group_factory
|
| 673 |
+
|
| 674 |
+
self.invalidate_json_state()
|
| 675 |
+
|
| 676 |
+
def set_ray_actor(self, ray_actor):
|
| 677 |
+
self.temporary_state.ray_actor = ray_actor
|
| 678 |
+
if ray_actor:
|
| 679 |
+
# Do not block here, the result will be gotten when last_result
|
| 680 |
+
# property is accessed
|
| 681 |
+
self._default_result_or_future = ray_actor.get_auto_filled_metrics.remote(
|
| 682 |
+
debug_metrics_only=True
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
def set_location(self, location):
|
| 686 |
+
"""Sets the location of the trial."""
|
| 687 |
+
self.temporary_state.location = location
|
| 688 |
+
|
| 689 |
+
def set_status(self, status):
|
| 690 |
+
"""Sets the status of the trial."""
|
| 691 |
+
self.status = status
|
| 692 |
+
if status == Trial.RUNNING:
|
| 693 |
+
if self.run_metadata.start_time is None:
|
| 694 |
+
self.run_metadata.start_time = time.time()
|
| 695 |
+
self.invalidate_json_state()
|
| 696 |
+
|
| 697 |
+
def set_config(self, config):
|
| 698 |
+
self.config = config
|
| 699 |
+
self.invalidate_json_state()
|
| 700 |
+
|
| 701 |
+
def set_experiment_tag(self, experiment_tag):
|
| 702 |
+
self.experiment_tag = experiment_tag
|
| 703 |
+
self.invalidate_json_state()
|
| 704 |
+
|
| 705 |
+
def set_storage(self, new_storage: StorageContext):
|
| 706 |
+
"""Updates the storage context of the trial.
|
| 707 |
+
|
| 708 |
+
If the `storage_path` or `experiment_dir_name` has changed, then this setter
|
| 709 |
+
also updates the paths of all checkpoints tracked by the checkpoint manager.
|
| 710 |
+
This enables restoration from a checkpoint if the user moves the directory.
|
| 711 |
+
"""
|
| 712 |
+
original_storage = self.storage
|
| 713 |
+
|
| 714 |
+
checkpoint_manager = self.run_metadata.checkpoint_manager
|
| 715 |
+
|
| 716 |
+
for checkpoint_result in checkpoint_manager.best_checkpoint_results:
|
| 717 |
+
checkpoint_result.checkpoint = Checkpoint(
|
| 718 |
+
path=checkpoint_result.checkpoint.path.replace(
|
| 719 |
+
original_storage.trial_fs_path, new_storage.trial_fs_path, 1
|
| 720 |
+
),
|
| 721 |
+
filesystem=new_storage.storage_filesystem,
|
| 722 |
+
)
|
| 723 |
+
latest_checkpoint_result = checkpoint_manager.latest_checkpoint_result
|
| 724 |
+
if latest_checkpoint_result:
|
| 725 |
+
latest_checkpoint_result.checkpoint = Checkpoint(
|
| 726 |
+
path=latest_checkpoint_result.checkpoint.path.replace(
|
| 727 |
+
original_storage.trial_fs_path, new_storage.trial_fs_path, 1
|
| 728 |
+
),
|
| 729 |
+
filesystem=new_storage.storage_filesystem,
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
self.storage = new_storage
|
| 733 |
+
self.invalidate_json_state()
|
| 734 |
+
|
| 735 |
+
@property
|
| 736 |
+
def num_failures(self):
|
| 737 |
+
return self.run_metadata.num_failures
|
| 738 |
+
|
| 739 |
+
@property
|
| 740 |
+
def num_failures_after_restore(self):
|
| 741 |
+
return self.run_metadata.num_failures_after_restore
|
| 742 |
+
|
| 743 |
+
@property
|
| 744 |
+
def error_file(self):
|
| 745 |
+
if not self.local_path or not self.run_metadata.error_filename:
|
| 746 |
+
return None
|
| 747 |
+
return Path(self.local_path, self.run_metadata.error_filename).as_posix()
|
| 748 |
+
|
| 749 |
+
@property
|
| 750 |
+
def pickled_error_file(self):
|
| 751 |
+
if not self.local_path or not self.run_metadata.pickled_error_filename:
|
| 752 |
+
return None
|
| 753 |
+
return Path(
|
| 754 |
+
self.local_path, self.run_metadata.pickled_error_filename
|
| 755 |
+
).as_posix()
|
| 756 |
+
|
| 757 |
+
def get_pickled_error(self) -> Optional[Exception]:
|
| 758 |
+
"""Returns the pickled error object if it exists in storage.
|
| 759 |
+
|
| 760 |
+
This is a pickled version of the latest error that the trial encountered.
|
| 761 |
+
"""
|
| 762 |
+
error_filename = self.run_metadata.pickled_error_filename
|
| 763 |
+
if error_filename is None:
|
| 764 |
+
return None
|
| 765 |
+
|
| 766 |
+
fs = self.storage.storage_filesystem
|
| 767 |
+
pickled_error_fs_path = Path(
|
| 768 |
+
self.storage.trial_fs_path, error_filename
|
| 769 |
+
).as_posix()
|
| 770 |
+
|
| 771 |
+
if _exists_at_fs_path(fs=fs, fs_path=pickled_error_fs_path):
|
| 772 |
+
with fs.open_input_stream(pickled_error_fs_path) as f:
|
| 773 |
+
return cloudpickle.loads(f.readall())
|
| 774 |
+
return None
|
| 775 |
+
|
| 776 |
+
def get_error(self) -> Optional[TuneError]:
|
| 777 |
+
"""Returns the error text file trace as a TuneError object
|
| 778 |
+
if it exists in storage.
|
| 779 |
+
|
| 780 |
+
This is a text trace of the latest error that the trial encountered,
|
| 781 |
+
which is used in the case that the error is not picklable.
|
| 782 |
+
"""
|
| 783 |
+
error_filename = self.run_metadata.error_filename
|
| 784 |
+
if error_filename is None:
|
| 785 |
+
return None
|
| 786 |
+
|
| 787 |
+
fs = self.storage.storage_filesystem
|
| 788 |
+
txt_error_fs_path = Path(self.storage.trial_fs_path, error_filename).as_posix()
|
| 789 |
+
|
| 790 |
+
if _exists_at_fs_path(fs=fs, fs_path=txt_error_fs_path):
|
| 791 |
+
with fs.open_input_stream(txt_error_fs_path) as f:
|
| 792 |
+
return f.readall().decode()
|
| 793 |
+
return None
|
| 794 |
+
|
| 795 |
+
def _handle_restore_error(self, exc: Exception):
|
| 796 |
+
# For Restoration errors, we only increment the restore failure count
|
| 797 |
+
# if the number of failures exceeds the restore retry limit.
|
| 798 |
+
if self.temporary_state.num_restore_failures >= int(
|
| 799 |
+
os.environ.get("TUNE_RESTORE_RETRY_NUM", 0)
|
| 800 |
+
):
|
| 801 |
+
self.run_metadata.num_failures += 1
|
| 802 |
+
else:
|
| 803 |
+
self.temporary_state.num_restore_failures += 1
|
| 804 |
+
|
| 805 |
+
def _handle_ray_actor_error(self, exc: RayActorError):
|
| 806 |
+
count_preemption_errors = bool(
|
| 807 |
+
int(os.environ.get(RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE, "0"))
|
| 808 |
+
)
|
| 809 |
+
if not exc.preempted or count_preemption_errors:
|
| 810 |
+
# Only count non-preempted actor errors as failures.
|
| 811 |
+
self.run_metadata.num_failures += 1
|
| 812 |
+
|
| 813 |
+
def _handle_ray_task_error(self, exc: RayTaskError):
|
| 814 |
+
cause = exc.as_instanceof_cause()
|
| 815 |
+
if isinstance(cause, RayActorError):
|
| 816 |
+
# Handle the RayActorError directly (ex: Ray Train worker actor errors)
|
| 817 |
+
return self._handle_ray_actor_error(cause)
|
| 818 |
+
|
| 819 |
+
# Increment failures for all user errors (which get raised as RayTaskError)
|
| 820 |
+
self.run_metadata.num_failures += 1
|
| 821 |
+
|
| 822 |
+
def handle_error(
|
| 823 |
+
self, exc: Optional[Union[TuneError, RayTaskError, RayActorError]] = None
|
| 824 |
+
):
|
| 825 |
+
if self.is_restoring:
|
| 826 |
+
self._handle_restore_error(exc)
|
| 827 |
+
elif isinstance(exc, RayActorError):
|
| 828 |
+
self._handle_ray_actor_error(exc)
|
| 829 |
+
elif isinstance(exc, RayTaskError):
|
| 830 |
+
self._handle_ray_task_error(exc)
|
| 831 |
+
else:
|
| 832 |
+
self.run_metadata.num_failures += 1
|
| 833 |
+
|
| 834 |
+
if self.local_path:
|
| 835 |
+
self.run_metadata.error_filename = EXPR_ERROR_FILE
|
| 836 |
+
if isinstance(exc, (RayTaskError, RayActorError)):
|
| 837 |
+
# Piping through the actual error to result grid.
|
| 838 |
+
self.run_metadata.pickled_error_filename = EXPR_ERROR_PICKLE_FILE
|
| 839 |
+
with open(self.pickled_error_file, "wb") as f:
|
| 840 |
+
cloudpickle.dump(exc, f)
|
| 841 |
+
with open(self.error_file, "a+") as f:
|
| 842 |
+
f.write(
|
| 843 |
+
"Failure # {} (occurred at {})\n".format(
|
| 844 |
+
self.run_metadata.num_failures, date_str()
|
| 845 |
+
)
|
| 846 |
+
)
|
| 847 |
+
f.write(str(exc) + "\n")
|
| 848 |
+
self.run_metadata.invalidate_cache()
|
| 849 |
+
|
| 850 |
+
def should_stop(self, result):
|
| 851 |
+
"""Whether the given result meets this trial's stopping criteria."""
|
| 852 |
+
if result.get(DONE):
|
| 853 |
+
return True
|
| 854 |
+
|
| 855 |
+
for criterion, stop_value in self.stopping_criterion.items():
|
| 856 |
+
if isinstance(criterion, dict):
|
| 857 |
+
raise ValueError(
|
| 858 |
+
"Stopping criteria is now flattened by default. "
|
| 859 |
+
"Use forward slashes to nest values `key1/key2/key3`."
|
| 860 |
+
)
|
| 861 |
+
elif criterion not in result:
|
| 862 |
+
if log_once("tune_trial_stop_criterion_not_found"):
|
| 863 |
+
logger.warning(
|
| 864 |
+
f"Stopping criterion '{criterion}' not found in result dict! "
|
| 865 |
+
f"Available keys are {list(result.keys())}. If '{criterion}' is"
|
| 866 |
+
" never reported, the run will continue until training is "
|
| 867 |
+
"finished."
|
| 868 |
+
)
|
| 869 |
+
elif result[criterion] >= stop_value:
|
| 870 |
+
return True
|
| 871 |
+
return False
|
| 872 |
+
|
| 873 |
+
def should_checkpoint(self):
|
| 874 |
+
"""Whether this trial is due for checkpointing."""
|
| 875 |
+
result = self.last_result or {}
|
| 876 |
+
if result.get(DONE) and self.checkpoint_at_end:
|
| 877 |
+
return True
|
| 878 |
+
return (
|
| 879 |
+
self.checkpoint_freq
|
| 880 |
+
and result.get(TRAINING_ITERATION, 0) % self.checkpoint_freq == 0
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
def has_checkpoint(self) -> bool:
|
| 884 |
+
return self.checkpoint is not None
|
| 885 |
+
|
| 886 |
+
def on_checkpoint(self, checkpoint_result: _TrainingResult):
|
| 887 |
+
"""Hook for handling checkpoints taken by the Trainable.
|
| 888 |
+
|
| 889 |
+
Args:
|
| 890 |
+
checkpoint: Checkpoint taken.
|
| 891 |
+
"""
|
| 892 |
+
self.run_metadata.checkpoint_manager.register_checkpoint(checkpoint_result)
|
| 893 |
+
# Update the checkpoint index to keep the checkpoint index in sync.
|
| 894 |
+
# This index will get restored when the trial is restored and will
|
| 895 |
+
# be passed to the Trainable as the starting checkpoint index.
|
| 896 |
+
self.storage._update_checkpoint_index(checkpoint_result.metrics)
|
| 897 |
+
|
| 898 |
+
self.invalidate_json_state()
|
| 899 |
+
self.run_metadata.invalidate_cache()
|
| 900 |
+
|
| 901 |
+
def on_restore(self):
|
| 902 |
+
"""Handles restoration completion."""
|
| 903 |
+
assert self.is_restoring
|
| 904 |
+
self.run_metadata.last_result = self.temporary_state.restoring_from.metrics
|
| 905 |
+
self.run_metadata.last_result.setdefault("config", self.config)
|
| 906 |
+
self.temporary_state.restoring_from = None
|
| 907 |
+
self.temporary_state.num_restore_failures = 0
|
| 908 |
+
|
| 909 |
+
def should_recover(self):
|
| 910 |
+
"""Returns whether the trial qualifies for retrying.
|
| 911 |
+
|
| 912 |
+
`num_failures` should represent the number of times the trial has
|
| 913 |
+
failed *up to the moment this method is called.* If we've failed
|
| 914 |
+
5 times and `max_failures=5`, then we should recover, since
|
| 915 |
+
we only pass the limit on the 6th failure.
|
| 916 |
+
|
| 917 |
+
Note this may return true even when there is no checkpoint, either because
|
| 918 |
+
`self.checkpoint_freq` is `0` or because the trial failed before
|
| 919 |
+
a checkpoint has been made.
|
| 920 |
+
"""
|
| 921 |
+
return (
|
| 922 |
+
self.run_metadata.num_failures <= self.max_failures or self.max_failures < 0
|
| 923 |
+
)
|
| 924 |
+
|
| 925 |
+
def update_last_result(self, result):
|
| 926 |
+
if self.experiment_tag:
|
| 927 |
+
result.update(experiment_tag=self.experiment_tag)
|
| 928 |
+
|
| 929 |
+
self.set_location(_Location(result.get(NODE_IP), result.get(PID)))
|
| 930 |
+
self.run_metadata.last_result = result
|
| 931 |
+
self.run_metadata.last_result_time = time.time()
|
| 932 |
+
|
| 933 |
+
metric_result = self.last_result.copy()
|
| 934 |
+
for remove_metric in DEBUG_METRICS:
|
| 935 |
+
metric_result.pop(remove_metric, None)
|
| 936 |
+
|
| 937 |
+
for metric, value in flatten_dict(metric_result).items():
|
| 938 |
+
if isinstance(value, Number):
|
| 939 |
+
self.run_metadata.update_metric(
|
| 940 |
+
metric, value, step=result.get("training_iteration")
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
def get_trainable_cls(self):
|
| 944 |
+
if self.stub:
|
| 945 |
+
return None
|
| 946 |
+
return get_trainable_cls(self.trainable_name)
|
| 947 |
+
|
| 948 |
+
def is_finished(self):
|
| 949 |
+
return self.status in [Trial.ERROR, Trial.TERMINATED]
|
| 950 |
+
|
| 951 |
+
@property
|
| 952 |
+
def is_restoring(self):
|
| 953 |
+
return self.temporary_state.restoring_from is not None
|
| 954 |
+
|
| 955 |
+
@property
|
| 956 |
+
def is_saving(self):
|
| 957 |
+
return self.temporary_state.saving_to is not None
|
| 958 |
+
|
| 959 |
+
def __repr__(self):
|
| 960 |
+
return self._trainable_name(include_trial_id=True)
|
| 961 |
+
|
| 962 |
+
def __str__(self):
|
| 963 |
+
return self._trainable_name(include_trial_id=True)
|
| 964 |
+
|
| 965 |
+
def _trainable_name(self, include_trial_id=False):
|
| 966 |
+
"""Combines ``env`` with ``trainable_name`` and ``trial_id``.
|
| 967 |
+
|
| 968 |
+
Can be overridden with a custom string creator.
|
| 969 |
+
"""
|
| 970 |
+
if self.custom_trial_name:
|
| 971 |
+
return self.custom_trial_name
|
| 972 |
+
|
| 973 |
+
if "env" in self.config:
|
| 974 |
+
env = self.config["env"]
|
| 975 |
+
if isinstance(env, type):
|
| 976 |
+
env = env.__name__
|
| 977 |
+
identifier = "{}_{}".format(self.trainable_name, env)
|
| 978 |
+
else:
|
| 979 |
+
identifier = self.trainable_name
|
| 980 |
+
if include_trial_id:
|
| 981 |
+
identifier += "_" + self.trial_id
|
| 982 |
+
return identifier.replace("/", "_")
|
| 983 |
+
|
| 984 |
+
def _generate_dirname(self):
|
| 985 |
+
if self.custom_dirname:
|
| 986 |
+
generated_dirname = self.custom_dirname
|
| 987 |
+
else:
|
| 988 |
+
MAX_LEN_IDENTIFIER = int(os.environ.get("TUNE_MAX_LEN_IDENTIFIER", "130"))
|
| 989 |
+
generated_dirname = f"{str(self)}_{self.experiment_tag}"
|
| 990 |
+
generated_dirname = generated_dirname[:MAX_LEN_IDENTIFIER]
|
| 991 |
+
generated_dirname += f"_{date_str()}"
|
| 992 |
+
# This is the file path used by rsync. ['/', '(', ')'] are not allowed.
|
| 993 |
+
return re.sub("[/()]", "_", generated_dirname)
|
| 994 |
+
|
| 995 |
+
def invalidate_json_state(self):
|
| 996 |
+
self._state_json = None
|
| 997 |
+
|
| 998 |
+
def get_json_state(self) -> Tuple[str, str]:
|
| 999 |
+
if self._state_json is None:
|
| 1000 |
+
state = self.__getstate__()
|
| 1001 |
+
state.pop("run_metadata", None)
|
| 1002 |
+
self._state_json = json.dumps(state, indent=2, cls=TuneFunctionEncoder)
|
| 1003 |
+
|
| 1004 |
+
runtime_metadata_json = self.run_metadata.get_json_state()
|
| 1005 |
+
|
| 1006 |
+
return self._state_json, runtime_metadata_json
|
| 1007 |
+
|
| 1008 |
+
@classmethod
|
| 1009 |
+
def from_json_state(cls, json_state: str, stub: bool = False) -> "Trial":
|
| 1010 |
+
state = json.loads(json_state, cls=TuneFunctionDecoder)
|
| 1011 |
+
|
| 1012 |
+
new_trial = Trial(
|
| 1013 |
+
state["trainable_name"],
|
| 1014 |
+
stub=stub,
|
| 1015 |
+
_setup_default_resource=False,
|
| 1016 |
+
)
|
| 1017 |
+
|
| 1018 |
+
new_trial.__setstate__(state)
|
| 1019 |
+
|
| 1020 |
+
return new_trial
|
| 1021 |
+
|
| 1022 |
+
def restore_run_metadata(self, run_metadata: str):
|
| 1023 |
+
self.run_metadata = _TrainingRunMetadata.from_json_state(run_metadata)
|
| 1024 |
+
|
| 1025 |
+
@classmethod
|
| 1026 |
+
def from_directory(
|
| 1027 |
+
cls, path: Union[str, os.PathLike], stub: bool = False
|
| 1028 |
+
) -> "Trial":
|
| 1029 |
+
metadata_path = Path(path, TRIAL_STATE_FILENAME)
|
| 1030 |
+
if not metadata_path.exists():
|
| 1031 |
+
raise FileNotFoundError(
|
| 1032 |
+
f"Can't restore trial from path: File `{metadata_path}` not found."
|
| 1033 |
+
)
|
| 1034 |
+
|
| 1035 |
+
json_state = metadata_path.read_text()
|
| 1036 |
+
return cls.from_json_state(json_state, stub=stub)
|
| 1037 |
+
|
| 1038 |
+
def __getstate__(self):
|
| 1039 |
+
"""Memento generator for Trial.
|
| 1040 |
+
|
| 1041 |
+
Sets RUNNING trials to PENDING.
|
| 1042 |
+
Note this can only occur if the trial holds a PERSISTENT checkpoint.
|
| 1043 |
+
"""
|
| 1044 |
+
state = self.__dict__.copy()
|
| 1045 |
+
|
| 1046 |
+
for key in self._nonjson_fields:
|
| 1047 |
+
state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))
|
| 1048 |
+
|
| 1049 |
+
state.pop("temporary_state", None)
|
| 1050 |
+
|
| 1051 |
+
state["_state_json"] = None
|
| 1052 |
+
state["_default_result_or_future"] = None
|
| 1053 |
+
|
| 1054 |
+
return state
|
| 1055 |
+
|
| 1056 |
+
def __setstate__(self, state):
|
| 1057 |
+
if state["status"] == Trial.RUNNING:
|
| 1058 |
+
state["status"] = Trial.PENDING
|
| 1059 |
+
for key in self._nonjson_fields:
|
| 1060 |
+
if key in state:
|
| 1061 |
+
state[key] = cloudpickle.loads(hex_to_binary(state[key]))
|
| 1062 |
+
|
| 1063 |
+
# Ensure that stub doesn't get overriden
|
| 1064 |
+
stub = state.pop("stub", True)
|
| 1065 |
+
self.__dict__.update(state)
|
| 1066 |
+
self.stub = stub or getattr(self, "stub", False)
|
| 1067 |
+
|
| 1068 |
+
if not self.stub:
|
| 1069 |
+
validate_trainable(self.trainable_name)
|
| 1070 |
+
|
| 1071 |
+
self.temporary_state = _TemporaryTrialState()
|
| 1072 |
+
|
| 1073 |
+
assert self.placement_group_factory
|
.venv/lib/python3.11/site-packages/ray/tune/integration/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/keras.cpython-311.pyc
ADDED
|
Binary file (1.64 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/lightgbm.cpython-311.pyc
ADDED
|
Binary file (1.02 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/pytorch_lightning.cpython-311.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/ray_train.cpython-311.pyc
ADDED
|
Binary file (1.69 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/integration/__pycache__/xgboost.cpython-311.pyc
ADDED
|
Binary file (1.01 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/integration/keras.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_DEPRECATION_MESSAGE = (
|
| 2 |
+
"The `ray.tune.integration.keras` module is deprecated in favor of "
|
| 3 |
+
"`ray.train.tensorflow.keras.ReportCheckpointCallback`."
|
| 4 |
+
)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TuneReportCallback:
|
| 8 |
+
"""Deprecated.
|
| 9 |
+
Use :class:`ray.train.tensorflow.keras.ReportCheckpointCallback` instead."""
|
| 10 |
+
|
| 11 |
+
def __new__(cls, *args, **kwargs):
|
| 12 |
+
raise DeprecationWarning(_DEPRECATION_MESSAGE)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class _TuneCheckpointCallback:
|
| 16 |
+
"""Deprecated.
|
| 17 |
+
Use :class:`ray.train.tensorflow.keras.ReportCheckpointCallback` instead."""
|
| 18 |
+
|
| 19 |
+
def __new__(cls, *args, **kwargs):
|
| 20 |
+
raise DeprecationWarning(_DEPRECATION_MESSAGE)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TuneReportCheckpointCallback:
|
| 24 |
+
"""Deprecated.
|
| 25 |
+
Use :class:`ray.train.tensorflow.keras.ReportCheckpointCallback` instead."""
|
| 26 |
+
|
| 27 |
+
def __new__(cls, *args, **kwargs):
|
| 28 |
+
raise DeprecationWarning(_DEPRECATION_MESSAGE)
|