koichi12 commited on
Commit
68246e2
·
verified ·
1 Parent(s): 39cf1df

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/ray/tune/analysis/__init__.py +3 -0
  2. .venv/lib/python3.11/site-packages/ray/tune/analysis/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/tune/analysis/__pycache__/experiment_analysis.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/tune/analysis/experiment_analysis.py +678 -0
  5. .venv/lib/python3.11/site-packages/ray/tune/impl/__init__.py +0 -0
  6. .venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/__init__.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/config.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/out_of_band_serialize_dataset.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/placeholder.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/test_utils.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/tuner_internal.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/tune/impl/config.py +46 -0
  13. .venv/lib/python3.11/site-packages/ray/tune/impl/out_of_band_serialize_dataset.py +33 -0
  14. .venv/lib/python3.11/site-packages/ray/tune/impl/placeholder.py +244 -0
  15. .venv/lib/python3.11/site-packages/ray/tune/impl/test_utils.py +66 -0
  16. .venv/lib/python3.11/site-packages/ray/tune/impl/tuner_internal.py +669 -0
  17. .venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/__init__.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/resource_changing_scheduler.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/ray/tune/search/__init__.py +153 -0
  20. .venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/concurrency_limiter.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/sample.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/search_algorithm.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/search_generator.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/tune/search/_mock.py +55 -0
  25. .venv/lib/python3.11/site-packages/ray/tune/search/ax/__pycache__/__init__.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/tune/search/ax/__pycache__/ax_search.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/ray/tune/search/basic_variant.py +421 -0
  28. .venv/lib/python3.11/site-packages/ray/tune/search/concurrency_limiter.py +176 -0
  29. .venv/lib/python3.11/site-packages/ray/tune/search/hebo/__init__.py +3 -0
  30. .venv/lib/python3.11/site-packages/ray/tune/search/hebo/__pycache__/__init__.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/ray/tune/search/hebo/__pycache__/hebo_search.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/__init__.py +3 -0
  33. .venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/__pycache__/__init__.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/__pycache__/nevergrad_search.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/nevergrad_search.py +373 -0
  36. .venv/lib/python3.11/site-packages/ray/tune/search/optuna/__init__.py +3 -0
  37. .venv/lib/python3.11/site-packages/ray/tune/search/optuna/__pycache__/__init__.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/tune/search/optuna/__pycache__/optuna_search.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/ray/tune/search/optuna/optuna_search.py +733 -0
  40. .venv/lib/python3.11/site-packages/ray/tune/search/repeater.py +199 -0
  41. .venv/lib/python3.11/site-packages/ray/tune/search/sample.py +742 -0
  42. .venv/lib/python3.11/site-packages/ray/tune/search/search_algorithm.py +127 -0
  43. .venv/lib/python3.11/site-packages/ray/tune/search/search_generator.py +222 -0
  44. .venv/lib/python3.11/site-packages/ray/tune/search/searcher.py +597 -0
  45. .venv/lib/python3.11/site-packages/ray/tune/search/util.py +31 -0
  46. .venv/lib/python3.11/site-packages/ray/tune/search/variant_generator.py +523 -0
  47. .venv/lib/python3.11/site-packages/ray/tune/stopper/__init__.py +18 -0
  48. .venv/lib/python3.11/site-packages/ray/tune/stopper/__pycache__/__init__.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/ray/tune/stopper/__pycache__/experiment_plateau.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/ray/tune/stopper/__pycache__/function_stopper.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/tune/analysis/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from ray.tune.analysis.experiment_analysis import ExperimentAnalysis
2
+
3
+ __all__ = ["ExperimentAnalysis"]
.venv/lib/python3.11/site-packages/ray/tune/analysis/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (310 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/analysis/__pycache__/experiment_analysis.cpython-311.pyc ADDED
Binary file (35.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/analysis/experiment_analysis.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import io
3
+ import json
4
+ import logging
5
+ import os
6
+ from numbers import Number
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List, Optional, Tuple, Union
9
+
10
+ import pyarrow.fs
11
+
12
+ from ray.air.constants import EXPR_PROGRESS_FILE, EXPR_RESULT_FILE, TRAINING_ITERATION
13
+ from ray.train import Checkpoint
14
+ from ray.train._internal.storage import _exists_at_fs_path, get_fs_and_path
15
+ from ray.tune.execution.experiment_state import _find_newest_experiment_checkpoint
16
+ from ray.tune.execution.tune_controller import TuneController
17
+ from ray.tune.experiment import Trial
18
+ from ray.tune.result import CONFIG_PREFIX, DEFAULT_METRIC
19
+ from ray.tune.utils import flatten_dict
20
+ from ray.tune.utils.serialization import TuneFunctionDecoder
21
+ from ray.tune.utils.util import is_nan, is_nan_or_inf, unflattened_lookup
22
+ from ray.util.annotations import PublicAPI
23
+
24
+ try:
25
+ import pandas as pd
26
+ from pandas import DataFrame
27
+ except ImportError:
28
+ pd = None
29
+ DataFrame = None
30
+
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
@PublicAPI(stability="beta")
class ExperimentAnalysis:
    """Analyze results from a Ray Train/Tune experiment.

    To use this class, the run must store the history of reported metrics
    in log files (e.g., `result.json` and `progress.csv`).
    This is the default behavior, unless default loggers are explicitly excluded
    with the `TUNE_DISABLE_AUTO_CALLBACK_LOGGERS=1` environment variable.

    Parameters:
        experiment_checkpoint_path: Path to an `experiment_state.json` file,
            or a directory that contains an `experiment_state.json` file.
        default_metric: Default metric for comparing results. Can be
            overwritten with the ``metric`` parameter in the respective
            functions.
        default_mode: Default mode for comparing results. Has to be one
            of [min, max]. Can be overwritten with the ``mode`` parameter
            in the respective functions.
        trials: List of trials that can be accessed via `analysis.trials`.
    """

    def __init__(
        self,
        experiment_checkpoint_path: Union[str, os.PathLike],
        *,
        storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
        trials: Optional[List[Trial]] = None,
        default_metric: Optional[str] = None,
        default_mode: Optional[str] = None,
    ):
        self.default_metric = default_metric
        if default_mode and default_mode not in ["min", "max"]:
            raise ValueError("`default_mode` has to be None or one of [min, max]")
        self.default_mode = default_mode
        if self.default_metric is None and self.default_mode is not None:
            # If only a mode was passed, use anonymous metric
            self.default_metric = DEFAULT_METRIC

        # Resolve the filesystem if not specified.
        if storage_filesystem:
            self._fs = storage_filesystem
        else:
            # `get_fs_and_path` may also rewrite the path (e.g. strip a URI
            # scheme), so the returned path is used from here on.
            self._fs, experiment_checkpoint_path = get_fs_and_path(
                experiment_checkpoint_path
            )

        # Find the json state file.
        experiment_checkpoint_path = str(experiment_checkpoint_path)
        if experiment_checkpoint_path.endswith(".json"):
            # A concrete snapshot file was given; its parent dir is the
            # experiment directory.
            self._experiment_fs_path = os.path.dirname(experiment_checkpoint_path)
            self._experiment_json_fs_path = experiment_checkpoint_path
        else:
            # A directory was given; locate the newest experiment snapshot
            # inside it.
            self._experiment_fs_path = experiment_checkpoint_path

            experiment_json_fs_path = _find_newest_experiment_checkpoint(
                experiment_path=self._experiment_fs_path, fs=self._fs
            )
            if experiment_json_fs_path is None:
                pattern = TuneController.CKPT_FILE_TMPL.format("*")
                raise ValueError(
                    f"No experiment snapshot file of form '{pattern}' was found at: "
                    f"({self._fs.type_name}, {self._experiment_fs_path})\n"
                    "Please check if you specified the correct experiment path, "
                    "which should be a combination of the `storage_path` and `name` "
                    "specified in your run."
                )

            self._experiment_json_fs_path = experiment_json_fs_path

        # Eagerly load trials, per-trial metric dataframes, and configs.
        self.trials = trials or self._load_trials()
        self._trial_dataframes = self._fetch_trial_dataframes()
        self._configs = self.get_all_configs()

    def _load_trials(self) -> List[Trial]:
        """Reconstruct stub `Trial` objects from the experiment snapshot JSON,
        re-pointing each trial's storage at this analysis' filesystem/path."""
        with self._fs.open_input_stream(self._experiment_json_fs_path) as f:
            experiment_state = json.loads(f.readall(), cls=TuneFunctionDecoder)

        experiment_fs_path = Path(self._experiment_fs_path)

        trials = []
        trial_states = experiment_state["trial_data"]
        for trial_json_state, trial_runtime_metadata in trial_states:
            trial = Trial.from_json_state(trial_json_state, stub=True)
            trial.restore_run_metadata(trial_runtime_metadata)

            # Override the storage location recorded in the snapshot with the
            # location this analysis was constructed from, so that the trials
            # resolve correctly even if the experiment directory was moved.
            new_storage = copy.copy(trial.storage)
            new_storage.storage_fs_path = experiment_fs_path.parent.as_posix()
            new_storage.storage_filesystem = self._fs
            new_storage.experiment_dir_name = experiment_fs_path.name
            trial.set_storage(new_storage)

            trials.append(trial)
        return trials

    def _fetch_trial_dataframe(self, trial: Trial) -> DataFrame:
        """Read a single trial's reported metrics into a DataFrame.

        Prefers the newline-delimited JSON result file, falling back to the
        CSV progress file. Raises FileNotFoundError if neither exists.
        """
        force_dtype = {"trial_id": str}  # Never convert trial_id to float.

        # If there were no reported results, there are no result files to read
        # into a DataFrame.
        if trial.last_result is None:
            return DataFrame()

        json_fs_path = Path(trial.storage.trial_fs_path, EXPR_RESULT_FILE).as_posix()
        csv_fs_path = Path(trial.storage.trial_fs_path, EXPR_PROGRESS_FILE).as_posix()
        # Prefer reading the JSON if it exists.
        if _exists_at_fs_path(trial.storage.storage_filesystem, json_fs_path):
            with trial.storage.storage_filesystem.open_input_stream(json_fs_path) as f:
                content = f.readall().decode("utf-8").rstrip("\n")
                if not content:
                    return DataFrame()
                # One JSON object per line; nested keys are flattened with "/".
                json_list = [json.loads(row) for row in content.split("\n")]
            df = pd.json_normalize(json_list, sep="/")
        # Fallback to reading the CSV.
        elif _exists_at_fs_path(trial.storage.storage_filesystem, csv_fs_path):
            with trial.storage.storage_filesystem.open_input_stream(csv_fs_path) as f:
                csv_str = f.readall().decode("utf-8")
            df = pd.read_csv(io.StringIO(csv_str), dtype=force_dtype)
        else:
            raise FileNotFoundError(
                f"Could not fetch metrics for {trial}: both {EXPR_RESULT_FILE} and "
                f"{EXPR_PROGRESS_FILE} were not found at {trial.storage.trial_fs_path}"
            )

        return df

    def _fetch_trial_dataframes(self) -> Dict[str, DataFrame]:
        """Fetches trial dataframes from files.

        Returns:
            A dictionary mapping trial_id -> pd.DataFrame
        """
        failures = []

        trial_dfs = {}
        for trial in self.trials:
            try:
                trial_dfs[trial.trial_id] = self._fetch_trial_dataframe(trial)
            except Exception as e:
                # Best effort: a trial whose metrics can't be read gets an
                # empty DataFrame; all failures are reported once below.
                failures.append((trial, e))
                trial_dfs[trial.trial_id] = DataFrame()
                continue

        if failures:
            fail_str = "\n".join(
                [f"- {trial}: {repr(error)}" for trial, error in failures]
            )
            logger.warning(
                f"Failed to fetch metrics for {len(failures)} trial(s):\n{fail_str}"
            )
        return trial_dfs

    def get_all_configs(self, prefix: bool = False) -> Dict[str, Dict]:
        """Returns all trial hyperparameter configurations.

        Args:
            prefix: If True, flattens the config dict
                and prepends `config/`.

        Returns:
            Dict[str, Dict]: Mapping trial_id -> config dict
        """
        return {
            trial.trial_id: (
                flatten_dict({CONFIG_PREFIX: trial.config}) if prefix else trial.config
            )
            for trial in self.trials
        }

    @property
    def experiment_path(self) -> str:
        """Path pointing to the experiment directory on persistent storage.

        This can point to a remote storage location (e.g. S3) or to a local
        location (path on the head node)."""
        return self._experiment_fs_path

    @property
    def best_trial(self) -> Trial:
        """Get the best trial of the experiment

        The best trial is determined by comparing the last trial results
        using the `metric` and `mode` parameters passed to `tune.run()`.

        If you didn't pass these parameters, use
        `get_best_trial(metric, mode, scope)` instead.
        """
        if not self.default_metric or not self.default_mode:
            raise ValueError(
                "To fetch the `best_trial`, pass a `metric` and `mode` "
                "parameter to `tune.run()`. Alternatively, use the "
                "`get_best_trial(metric, mode)` method to set the metric "
                "and mode explicitly."
            )
        return self.get_best_trial(self.default_metric, self.default_mode)

    @property
    def best_config(self) -> Dict:
        """Get the config of the best trial of the experiment

        The best trial is determined by comparing the last trial results
        using the `metric` and `mode` parameters passed to `tune.run()`.

        If you didn't pass these parameters, use
        `get_best_config(metric, mode, scope)` instead.
        """
        if not self.default_metric or not self.default_mode:
            raise ValueError(
                "To fetch the `best_config`, pass a `metric` and `mode` "
                "parameter to `tune.run()`. Alternatively, use the "
                "`get_best_config(metric, mode)` method to set the metric "
                "and mode explicitly."
            )
        return self.get_best_config(self.default_metric, self.default_mode)

    @property
    def best_checkpoint(self) -> Checkpoint:
        """Get the checkpoint path of the best trial of the experiment

        The best trial is determined by comparing the last trial results
        using the `metric` and `mode` parameters passed to `tune.run()`.

        If you didn't pass these parameters, use
        `get_best_checkpoint(trial, metric, mode)` instead.

        Returns:
            :class:`Checkpoint <ray.train.Checkpoint>` object.
        """
        if not self.default_metric or not self.default_mode:
            raise ValueError(
                "To fetch the `best_checkpoint`, pass a `metric` and `mode` "
                "parameter to `tune.run()`. Alternatively, use the "
                "`get_best_checkpoint(trial, metric, mode)` method to set the "
                "metric and mode explicitly."
            )
        best_trial = self.best_trial
        if not best_trial:
            raise ValueError(
                f"No best trial found. Please check if you specified the "
                f"correct default metric ({self.default_metric}) and mode "
                f"({self.default_mode})."
            )
        return self.get_best_checkpoint(
            best_trial, self.default_metric, self.default_mode
        )

    @property
    def best_dataframe(self) -> DataFrame:
        """Get the full result dataframe of the best trial of the experiment

        The best trial is determined by comparing the last trial results
        using the `metric` and `mode` parameters passed to `tune.run()`.

        If you didn't pass these parameters, use
        `get_best_trial(metric, mode)` and use it to look for the dataframe
        in the `self.trial_dataframes` dict.
        """
        if not self.default_metric or not self.default_mode:
            raise ValueError(
                "To fetch the `best_result`, pass a `metric` and `mode` "
                "parameter to `tune.run()`."
            )
        return self.trial_dataframes[self.best_trial.trial_id]

    @property
    def best_result(self) -> Dict:
        """Get the last result of the best trial of the experiment

        The best trial is determined by comparing the last trial results
        using the `metric` and `mode` parameters passed to `tune.run()`.

        If you didn't pass these parameters, use
        `get_best_trial(metric, mode, scope).last_result` instead.
        """
        if not self.default_metric or not self.default_mode:
            raise ValueError(
                "To fetch the `best_result`, pass a `metric` and `mode` "
                "parameter to `tune.run()`. Alternatively, use "
                "`get_best_trial(metric, mode).last_result` to set "
                "the metric and mode explicitly and fetch the last result."
            )
        return self.best_trial.last_result

    def _delimiter(self) -> str:
        # Separator used when flattening nested result dicts into columns.
        return os.environ.get("TUNE_RESULT_DELIM", "/")

    @property
    def best_result_df(self) -> DataFrame:
        """Get the best result of the experiment as a pandas dataframe.

        The best trial is determined by comparing the last trial results
        using the `metric` and `mode` parameters passed to `tune.run()`.

        If you didn't pass these parameters, use
        `get_best_trial(metric, mode, scope).last_result` instead.
        """
        if not pd:
            raise ValueError(
                "`best_result_df` requires pandas. Install with "
                "`pip install pandas`."
            )

        best_result = flatten_dict(self.best_result, delimiter=self._delimiter())
        return pd.DataFrame.from_records([best_result], index="trial_id")

    @property
    def results(self) -> Dict[str, Dict]:
        """Get the last result of the all trials of the experiment"""
        return {trial.trial_id: trial.last_result for trial in self.trials}

    @property
    def results_df(self) -> DataFrame:
        """Get all the last results as a pandas dataframe."""
        if not pd:
            raise ValueError(
                "`results_df` requires pandas. Install with `pip install pandas`."
            )
        return pd.DataFrame.from_records(
            [
                flatten_dict(trial.last_result, delimiter=self._delimiter())
                for trial in self.trials
            ],
            index="trial_id",
        )

    @property
    def trial_dataframes(self) -> Dict[str, DataFrame]:
        """List of all dataframes of the trials.

        Each dataframe is indexed by iterations and contains reported
        metrics.
        """
        return self._trial_dataframes

    def dataframe(
        self, metric: Optional[str] = None, mode: Optional[str] = None
    ) -> DataFrame:
        """Returns a pandas.DataFrame object constructed from the trials.

        This function will look through all observed results of each trial
        and return the one corresponding to the passed ``metric`` and
        ``mode``: If ``mode=min``, it returns the result with the lowest
        *ever* observed ``metric`` for this trial (this is not necessarily
        the last)! For ``mode=max``, it's the highest, respectively. If
        ``metric=None`` or ``mode=None``, the last result will be returned.

        Args:
            metric: Key for trial info to order on. If None, uses last result.
            mode: One of [None, "min", "max"].

        Returns:
            pd.DataFrame: Constructed from a result dict of each trial.
        """
        # Do not validate metric/mode here or set from default metric/mode!
        # Otherwise we will get confusing results as the lowest ever observed
        # result may not be the last result.
        if mode and mode not in ["min", "max"]:
            raise ValueError("If set, `mode` has to be one of [min, max]")

        if mode and not metric:
            raise ValueError(
                "If a `mode` is passed to `ExperimentAnalysis.dataframe(),"
                " you'll also have to pass a `metric`!"
            )

        # NOTE(review): `rows` and `all_configs` are both keyed by trial_id
        # (see `_fetch_trial_dataframes` / `get_all_configs`), despite the
        # loop variable being named `path` — the `logdir` column therefore
        # holds trial ids here.
        rows = self._retrieve_rows(metric=metric, mode=mode)
        all_configs = self.get_all_configs(prefix=True)
        for path, config in all_configs.items():
            if path in rows:
                rows[path].update(config)
                rows[path].update(logdir=path)
        return pd.DataFrame(list(rows.values()))

    def _get_trial_checkpoints_with_metric(
        self, trial: Trial, metric: Optional[str] = None
    ) -> List[Tuple[Checkpoint, Number]]:
        """Get all checkpoints and a specified metric of a trial.

        Args:
            trial: The log directory of a trial, or a trial instance.
            metric: key for trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default if no value was
                passed to ``self.default_metric``.

        Returns:
            List of [Checkpoint, metric] for all checkpoints of the trial.
        """
        metric = metric or self.default_metric or TRAINING_ITERATION

        best_checkpoint_results = (
            trial.run_metadata.checkpoint_manager.best_checkpoint_results
        )
        best_checkpoints = [
            (checkpoint_result.checkpoint, checkpoint_result.metrics)
            for checkpoint_result in best_checkpoint_results
        ]
        # Support nested metrics given as flattened strings, e.g.
        # "info/learner/default_policy/policy_loss".
        return [
            (checkpoint, unflattened_lookup(metric, metrics))
            for checkpoint, metrics in best_checkpoints
        ]

    def get_best_checkpoint(
        self,
        trial: Trial,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
    ) -> Optional[Checkpoint]:
        """Gets best persistent checkpoint path of provided trial.

        Any checkpoints with an associated metric value of ``nan`` will be filtered out.

        Args:
            trial: The log directory of a trial, or a trial instance.
            metric: key of trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default if no value was
                passed to ``self.default_metric``.
            mode: One of [min, max]. Defaults to ``self.default_mode``.

        Returns:
            A :class:`Checkpoint <ray.train.Checkpoint>` object
        """
        metric = metric or self.default_metric or TRAINING_ITERATION
        mode = self._validate_mode(mode)

        checkpoints_and_metrics = self._get_trial_checkpoints_with_metric(trial, metric)

        # Filter out nan. Sorting nan values leads to undefined behavior.
        checkpoints_and_metrics = list(
            filter(lambda x: not is_nan(x[1]), checkpoints_and_metrics)
        )

        if not checkpoints_and_metrics:
            logger.error(f"No checkpoints have been found for trial {trial}.")
            return None

        # Negate scores for mode="min" so that `max` picks the smallest metric.
        score_order_factor = -1 if mode == "min" else 1
        best_checkpoint, _ = max(
            checkpoints_and_metrics, key=lambda x: score_order_factor * x[1]
        )
        return best_checkpoint

    def get_best_trial(
        self,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        scope: str = "last",
        filter_nan_and_inf: bool = True,
    ) -> Optional[Trial]:
        """Retrieve the best trial object.

        Compares all trials' scores on ``metric``.
        If ``metric`` is not specified, ``self.default_metric`` will be used.
        If `mode` is not specified, ``self.default_mode`` will be used.
        These values are usually initialized by passing the ``metric`` and
        ``mode`` parameters to ``tune.run()``.

        Args:
            metric: Key for trial info to order on. Defaults to
                ``self.default_metric``.
            mode: One of [min, max]. Defaults to ``self.default_mode``.
            scope: One of [all, last, avg, last-5-avg, last-10-avg].
                If `scope=last`, only look at each trial's final step for
                `metric`, and compare across trials based on `mode=[min,max]`.
                If `scope=avg`, consider the simple average over all steps
                for `metric` and compare across trials based on
                `mode=[min,max]`. If `scope=last-5-avg` or `scope=last-10-avg`,
                consider the simple average over the last 5 or 10 steps for
                `metric` and compare across trials based on `mode=[min,max]`.
                If `scope=all`, find each trial's min/max score for `metric`
                based on `mode`, and compare trials based on `mode=[min,max]`.
            filter_nan_and_inf: If True (default), NaN or infinite
                values are disregarded and these trials are never selected as
                the best trial.

        Returns:
            The best trial for the provided metric. If no trials contain the provided
            metric, or if the value for the metric is NaN for all trials,
            then returns None.
        """
        # Single-trial experiments short-circuit: that trial is trivially best.
        if len(self.trials) == 1:
            return self.trials[0]

        metric = self._validate_metric(metric)
        mode = self._validate_mode(mode)

        if scope not in ["all", "last", "avg", "last-5-avg", "last-10-avg"]:
            raise ValueError(
                "ExperimentAnalysis: attempting to get best trial for "
                'metric {} for scope {} not in ["all", "last", "avg", '
                '"last-5-avg", "last-10-avg"]. '
                "If you didn't pass a `metric` parameter to `tune.run()`, "
                "you have to pass one when fetching the best trial.".format(
                    metric, scope
                )
            )
        best_trial = None
        best_metric_score = None

        for trial in self.trials:
            if metric not in trial.metric_analysis:
                continue

            # For scope="all", the per-trial aggregate is the trial's own
            # min/max (keyed by `mode`); otherwise it's keyed by the scope.
            if scope in ["last", "avg", "last-5-avg", "last-10-avg"]:
                metric_score = trial.metric_analysis[metric][scope]
            else:
                metric_score = trial.metric_analysis[metric][mode]

            if filter_nan_and_inf and is_nan_or_inf(metric_score):
                continue

            if best_metric_score is None:
                best_metric_score = metric_score
                best_trial = trial
                continue

            if (mode == "max") and (best_metric_score < metric_score):
                best_metric_score = metric_score
                best_trial = trial
            elif (mode == "min") and (best_metric_score > metric_score):
                best_metric_score = metric_score
                best_trial = trial

        if not best_trial:
            logger.warning(
                "Could not find best trial. Did you pass the correct `metric` "
                "parameter?"
            )
        return best_trial

    def get_best_config(
        self,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        scope: str = "last",
    ) -> Optional[Dict]:
        """Retrieve the best config corresponding to the trial.

        Compares all trials' scores on `metric`.
        If ``metric`` is not specified, ``self.default_metric`` will be used.
        If `mode` is not specified, ``self.default_mode`` will be used.
        These values are usually initialized by passing the ``metric`` and
        ``mode`` parameters to ``tune.run()``.

        Args:
            metric: Key for trial info to order on. Defaults to
                ``self.default_metric``.
            mode: One of [min, max]. Defaults to ``self.default_mode``.
            scope: One of [all, last, avg, last-5-avg, last-10-avg].
                If `scope=last`, only look at each trial's final step for
                `metric`, and compare across trials based on `mode=[min,max]`.
                If `scope=avg`, consider the simple average over all steps
                for `metric` and compare across trials based on
                `mode=[min,max]`. If `scope=last-5-avg` or `scope=last-10-avg`,
                consider the simple average over the last 5 or 10 steps for
                `metric` and compare across trials based on `mode=[min,max]`.
                If `scope=all`, find each trial's min/max score for `metric`
                based on `mode`, and compare trials based on `mode=[min,max]`.
        """
        best_trial = self.get_best_trial(metric, mode, scope)
        return best_trial.config if best_trial else None

    def get_last_checkpoint(
        self, trial=None, metric="training_iteration", mode="max"
    ) -> Optional[Checkpoint]:
        """Gets the last checkpoint of the provided trial,
        i.e., with the highest "training_iteration".

        If no trial is specified, it loads the best trial according to the
        provided metric and mode (defaults to max. training iteration).

        Args:
            trial: If None, load the best trial automatically.
            metric: If no trial is specified, use this metric to identify
                the best trial and load the last checkpoint from this trial.
            mode: If no trial is specified, use the metric and this mode
                to identify the best trial and load the last checkpoint from it.

        Returns:
            Path for last checkpoint of trial
        """
        trial = trial or self.get_best_trial(metric, mode)
        # Note: `metric`/`mode` only select the trial; the checkpoint itself is
        # always the one with the highest training iteration.
        return self.get_best_checkpoint(trial, TRAINING_ITERATION, "max")

    def _validate_metric(self, metric: str) -> str:
        # Resolve the metric to use, preferring the explicit argument over
        # `self.default_metric`; raise if neither is available.
        if not metric and not self.default_metric:
            raise ValueError(
                "No `metric` has been passed and `default_metric` has "
                "not been set. Please specify the `metric` parameter."
            )
        return metric or self.default_metric

    def _validate_mode(self, mode: str) -> str:
        # Resolve the mode to use, preferring the explicit argument over
        # `self.default_mode`; raise if neither is available or it is invalid.
        if not mode and not self.default_mode:
            raise ValueError(
                "No `mode` has been passed and `default_mode` has "
                "not been set. Please specify the `mode` parameter."
            )
        if mode and mode not in ["min", "max"]:
            raise ValueError("If set, `mode` has to be one of [min, max]")
        return mode or self.default_mode

    def _retrieve_rows(
        self, metric: Optional[str] = None, mode: Optional[str] = None
    ) -> Dict[str, Any]:
        """Pick one result row per trial dataframe.

        With `mode` set, the row extremizing `metric` is chosen; otherwise
        (or if the metric column is missing) the last row is used. Returns a
        mapping keyed like `self.trial_dataframes` (trial_id -> row dict).
        """
        assert mode is None or mode in ["max", "min"]
        assert not mode or metric
        rows = {}
        for path, df in self.trial_dataframes.items():
            if df.empty:
                continue
            if metric not in df:
                idx = -1
            elif mode == "max":
                idx = df[metric].idxmax()
            elif mode == "min":
                idx = df[metric].idxmin()
            else:
                idx = -1
            try:
                rows[path] = df.iloc[idx].to_dict()
            except TypeError:
                # idx is nan
                logger.warning(
                    "Warning: Non-numerical value(s) encountered for {}".format(path)
                )

        return rows

    def __getstate__(self) -> Dict[str, Any]:
        """Ensure that trials are marked as stubs when pickling,
        so that they can be loaded later without the trainable
        being registered.
        """
        state = self.__dict__.copy()

        def make_stub_if_needed(trial: Trial) -> Trial:
            if trial.stub:
                return trial
            trial_copy = Trial(trial.trainable_name, stub=True)
            trial_copy.__setstate__(trial.__getstate__())
            return trial_copy

        state["trials"] = [make_stub_if_needed(t) for t in state["trials"]]
        return state
.venv/lib/python3.11/site-packages/ray/tune/impl/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (186 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/config.cpython-311.pyc ADDED
Binary file (2.63 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/out_of_band_serialize_dataset.cpython-311.pyc ADDED
Binary file (2.16 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/placeholder.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/test_utils.cpython-311.pyc ADDED
Binary file (4.84 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/tuner_internal.cpython-311.pyc ADDED
Binary file (28.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/impl/config.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ from ray.air.config import CheckpointConfig as _CheckpointConfig
4
+ from ray.air.config import FailureConfig as _FailureConfig
5
+ from ray.air.config import RunConfig as _RunConfig
6
+ from ray.train.constants import _v2_migration_warnings_enabled
7
+ from ray.train.utils import _copy_doc, _log_deprecation_warning
8
+
9
+ # NOTE: This is just a pass-through wrapper around `ray.train.RunConfig`
10
+ # in order to detect whether the import module was correct (e.g. `ray.tune.RunConfig`).
11
+
12
+
13
# Pass-through wrapper whose only purpose is to let the Tuner detect whether
# the config was imported from `ray.tune` rather than `ray.air.config`.
@dataclass
@_copy_doc(_CheckpointConfig)
class CheckpointConfig(_CheckpointConfig):
    pass
17
+
18
+
19
# Pass-through wrapper whose only purpose is to let the Tuner detect whether
# the config was imported from `ray.tune` rather than `ray.air.config`.
@dataclass
@_copy_doc(_FailureConfig)
class FailureConfig(_FailureConfig):
    pass
23
+
24
+
25
# Pass-through wrapper around the base `RunConfig` that defaults the nested
# configs to the `ray.tune` variants and warns when the non-Tune variants
# were passed in instead.
@dataclass
@_copy_doc(_RunConfig)
class RunConfig(_RunConfig):
    def __post_init__(self):
        # Default the nested configs to the Tune flavors so the
        # `isinstance` checks below pass for the default case.
        self.checkpoint_config = self.checkpoint_config or CheckpointConfig()
        self.failure_config = self.failure_config or FailureConfig()

        super().__post_init__()

        # Warn (behind the v2 migration flag) if the nested configs were
        # imported from the wrong module.
        if not isinstance(self.checkpoint_config, CheckpointConfig):
            if _v2_migration_warnings_enabled():
                _log_deprecation_warning(
                    "The `CheckpointConfig` class should be imported from `ray.tune` "
                    "when passing it to the Tuner. Please update your imports."
                )

        if not isinstance(self.failure_config, FailureConfig):
            if _v2_migration_warnings_enabled():
                _log_deprecation_warning(
                    "The `FailureConfig` class should be imported from `ray.tune` "
                    "when passing it to the Tuner. Please update your imports."
                )
.venv/lib/python3.11/site-packages/ray/tune/impl/out_of_band_serialize_dataset.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import traceback
3
+
4
+ import ray
5
+
6
+
7
def _deserialize_and_fully_execute_if_needed(serialized_ds: bytes):
    """Reconstruct a Dataset from its serialized lineage.

    Used as the deserialization callable by the custom reducer below.
    """
    ds = ray.data.Dataset.deserialize_lineage(serialized_ds)
    return ds
10
+
11
+
12
def _reduce(ds: ray.data.Dataset):
    """Custom cloudpickle reducer that pickles a Dataset via its lineage.

    Falls back to the default ``__reduce__`` when the dataset has no
    serializable lineage, or when we are already inside a
    ``serialize_lineage`` call (avoids recursing into ourselves).
    """
    # TODO(xwjiang): Let's make this less hacky.
    # Detect an in-progress out-of-band serialization by scanning the stack.
    stack_frames = traceback.format_list(traceback.extract_stack())
    already_serializing = any(
        "serialize_lineage" in frame for frame in stack_frames
    )

    if already_serializing or not ds.has_serializable_lineage():
        return ds.__reduce__()
    return _deserialize_and_fully_execute_if_needed, (ds.serialize_lineage(),)
24
+
25
+
26
@contextlib.contextmanager
def out_of_band_serialize_dataset():
    """Context manager that pickles Datasets via their lineage.

    Temporarily registers ``_reduce`` as the cloudpickle reducer for
    ``ray.data.Dataset`` and always unregisters it on exit.
    """
    context = ray._private.worker.global_worker.get_serialization_context()
    try:
        context._register_cloudpickle_reducer(ray.data.Dataset, _reduce)
        yield
    finally:
        # Restore default Dataset pickling even if the body raised.
        context._unregister_cloudpickle_reducer(ray.data.Dataset)
.venv/lib/python3.11/site-packages/ray/tune/impl/placeholder.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from collections import defaultdict
3
+ from typing import Any, Dict, Tuple
4
+
5
+ from ray.tune.search.sample import Categorical, Domain, Function
6
+ from ray.tune.search.variant_generator import assign_value
7
+ from ray.util.annotations import DeveloperAPI
8
+
9
# Number of hex digits kept from the SHA-1 digest for placeholder ids.
ID_HASH_LENGTH = 8


def create_resolvers_map():
    """Return a fresh mapping from placeholder path to a list of resolvers."""
    return defaultdict(list)


def _id_hash(path_tuple):
    """Compute a short, stable hash identifying a placeholder by its path."""
    digest = hashlib.sha1(str(path_tuple).encode("utf-8")).hexdigest()
    return digest[:ID_HASH_LENGTH]
19
+
20
+
21
class _FunctionResolver:
    """Replaced value for function typed objects."""

    # Marker used as the first element of the placeholder tuple.
    TOKEN = "__fn_ph"

    def __init__(self, hash, fn):
        self.hash = hash
        self._fn = fn

    def resolve(self, config: Dict):
        """Some functions take a resolved spec dict as input.

        Note: Function placeholders are independently sampled during
        resolution. Therefore their random states are not restored.
        """
        return self._fn.sample(config=config)

    def get_placeholder(self) -> Tuple[str, str]:
        # The placeholder is a (token, hash) tuple, not a plain string.
        return (self.TOKEN, self.hash)
40
+
41
+
42
class _RefResolver:
    """Replaced value for all other non-primitive objects."""

    # Marker used as the first element of the placeholder tuple.
    TOKEN = "__ref_ph"

    def __init__(self, hash, value):
        self.hash = hash
        self._value = value

    def resolve(self):
        # Simply hands back the original object.
        return self._value

    def get_placeholder(self) -> Tuple[str, str]:
        # The placeholder is a (token, hash) tuple, not a plain string.
        return (self.TOKEN, self.hash)
56
+
57
+
58
def _is_primitive(x):
    """Return True if ``x`` is a primitive value.

    Primitive types are int, float, str, bool, and None.
    """
    if x is None:
        return True
    return isinstance(x, (int, float, str, bool))
64
+
65
+
66
@DeveloperAPI
def inject_placeholders(
    config: Any,
    resolvers: defaultdict,
    id_prefix: Tuple = (),
    path_prefix: Tuple = (),
) -> Any:
    """Replaces reference objects contained by a config dict with placeholders.

    Given a config dict, this function replaces all reference objects contained
    by this dict with placeholder strings. It recursively expands nested dicts
    and lists, and properly handles Tune native search objects such as Categorical
    and Function.
    This makes sure the config dict only contains primitive typed values, which
    can then be handled by different search algorithms.

    A few details about id_prefix and path_prefix. Consider the following config,
    where "param1" is a simple grid search of 3 tuples.

    config = {
        "param1": tune.grid_search([
            (Cat, None, None),
            (None, Dog, None),
            (None, None, Fish),
        ]),
    }

    We will replace the 3 objects contained with placeholders. And after trial
    expansion, the config may look like this:

    config = {
        "param1": (None, (placeholder, hash), None)
    }

    Now you need 2 pieces of information to resolve the placeholder. One is the
    path of ("param1", 1), which tells you that the first element of the tuple
    under "param1" key is a placeholder that needs to be resolved.
    The other is the mapping from the placeholder to the actual object. In this
    case hash -> Dog.

    id and path prefixes serve exactly this purpose here. The difference between
    these two is that id_prefix is the location of the value in the pre-injected
    config tree. So if a value is the second option in a grid_search, it gets an
    id part of 1. Injected placeholders all get unique id prefixes. path prefix
    identifies a placeholder in the expanded config tree. So for example, all
    options of a single grid_search will get the same path prefix. This is how
    we know which location has a placeholder to be resolved in the post-expansion
    tree.

    Args:
        config: The config dict to replace references in.
        resolvers: A dict from path to replaced objects.
        id_prefix: The prefix to prepend to id every single placeholders.
        path_prefix: The prefix to prepend to every path identifying
            potential locations of placeholders in an expanded tree.

    Returns:
        The config with all references replaced (the return value may be a
        dict, list, tuple, placeholder tuple, or primitive, mirroring the
        input shape).
    """
    if isinstance(config, dict) and "grid_search" in config and len(config) == 1:
        # Special-case: a bare `{"grid_search": [...]}` spec. Mutated in place.
        config["grid_search"] = [
            # Different options gets different id prefixes.
            # But we should omit appending to path_prefix because after expansion,
            # this level will not be there.
            inject_placeholders(choice, resolvers, id_prefix + (i,), path_prefix)
            for i, choice in enumerate(config["grid_search"])
        ]
        return config
    elif isinstance(config, dict):
        # Recurse into each value; dict keys extend both prefixes.
        return {
            k: inject_placeholders(v, resolvers, id_prefix + (k,), path_prefix + (k,))
            for k, v in config.items()
        }
    elif isinstance(config, list):
        # List positions extend both prefixes.
        return [
            inject_placeholders(elem, resolvers, id_prefix + (i,), path_prefix + (i,))
            for i, elem in enumerate(config)
        ]
    elif isinstance(config, tuple):
        # Tuple positions extend both prefixes; result stays a tuple.
        return tuple(
            inject_placeholders(elem, resolvers, id_prefix + (i,), path_prefix + (i,))
            for i, elem in enumerate(config)
        )
    elif _is_primitive(config):
        # Primitive types.
        return config
    elif isinstance(config, Categorical):
        config.categories = [
            # Different options gets different id prefixes.
            # But we should omit appending to path_prefix because after expansion,
            # this level will not be there.
            inject_placeholders(choice, resolvers, id_prefix + (i,), path_prefix)
            for i, choice in enumerate(config.categories)
        ]
        return config
    elif isinstance(config, Function):
        # Function type.
        id_hash = _id_hash(id_prefix)
        v = _FunctionResolver(id_hash, config)
        resolvers[path_prefix].append(v)
        return v.get_placeholder()
    elif not isinstance(config, Domain):
        # Other non-search space reference objects, dataset, actor handle, etc.
        id_hash = _id_hash(id_prefix)
        v = _RefResolver(id_hash, config)
        resolvers[path_prefix].append(v)
        return v.get_placeholder()
    else:
        # All the other cases, do nothing.
        return config
176
+
177
+
178
def _get_placeholder(config: Any, prefix: Tuple, path: Tuple):
    """Walk ``config`` along ``path`` looking for a placeholder tuple.

    Returns:
        Tuple of (path at which the placeholder lives, the placeholder),
        or (None, None) if no placeholder is found along the path.
    """
    if not path:
        # Path fully consumed: the current node is the match.
        return prefix, config

    key = path[0]
    if isinstance(config, tuple):
        if config[0] in (_FunctionResolver.TOKEN, _RefResolver.TOKEN):
            # Found a matching placeholder.
            # Note that we do not require that the full path are consumed before
            # declaring a match. Because this placeholder may be part of a nested
            # search space. For example, the following config:
            # config = {
            #     "param1": tune.grid_search([
            #         tune.grid_search([Object1, 2, 3]),
            #         tune.grid_search([Object2, 5, 6]),
            #     ]),
            # }
            # will result in placeholders under path ("param1", 0, 0).
            # After expansion though, the chosen placeholder will live under path
            # ("param1", 0) like this: config = {"param1": (Placeholder1, 2, 3)}
            return prefix, config
        elif key < len(config):
            return _get_placeholder(
                config[key], prefix=prefix + (path[0],), path=path[1:]
            )
    elif (isinstance(config, dict) and key in config) or (
        isinstance(config, list) and key < len(config)
    ):
        # Expand config tree recursively.
        return _get_placeholder(config[key], prefix=prefix + (path[0],), path=path[1:])

    # Can not find a matching placeholder.
    return None, None
211
+
212
+
213
@DeveloperAPI
def resolve_placeholders(config: Any, replaced: defaultdict):
    """Replaces placeholders contained by a config dict with the original values.

    Args:
        config: The config to replace placeholders in (mutated in place).
        replaced: A dict from path to replaced objects.
    """

    def __resolve(resolver_type, args):
        # Resolve all placeholders handled by `resolver_type`, in place.
        for path, resolvers in replaced.items():
            assert resolvers

            # NOTE(review): only the first resolver's type is checked;
            # resolvers under one path appear to be homogeneous — confirm.
            if not isinstance(resolvers[0], resolver_type):
                continue

            prefix, ph = _get_placeholder(config, (), path)
            if not ph:
                # Represents an unchosen value. Just skip.
                continue

            for resolver in resolvers:
                if resolver.hash != ph[1]:
                    continue
                # Found the matching resolver.
                assign_value(config, prefix, resolver.resolve(*args))

    # RefResolvers first.
    __resolve(_RefResolver, args=())
    # Functions need to be resolved after RefResolvers, in case they are
    # referencing values from the RefResolvers.
    __resolve(_FunctionResolver, args=(config,))
.venv/lib/python3.11/site-packages/ray/tune/impl/test_utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.datasets import load_breast_cancer
2
+
3
+ from ray import tune
4
+ from ray.data import Dataset, Datasource, ReadTask, read_datasource
5
+ from ray.data.block import BlockMetadata
6
+ from ray.tune.impl.utils import execute_dataset
7
+
8
+
9
# TODO(xwjiang): Enable this when Clark's out-of-band-serialization is landed.
class TestDatasource(Datasource):
    """Datasource that reads the sklearn breast-cancer dataset as one block."""

    def prepare_read(self, parallelism: int, **read_args):
        import pyarrow as pa

        def load_data():
            # Load features and attach the label column as "target".
            data_raw = load_breast_cancer(as_frame=True)
            dataset_df = data_raw["data"]
            dataset_df["target"] = data_raw["target"]
            return [pa.Table.from_pandas(dataset_df)]

        # Block metadata is unknown up front; leave every field unset.
        meta = BlockMetadata(
            num_rows=None,
            size_bytes=None,
            schema=None,
            input_files=None,
            exec_stats=None,
        )
        return [ReadTask(load_data, meta)]
28
+
29
+
30
def gen_dataset_func() -> Dataset:
    """Create a fresh Dataset backed by ``TestDatasource``."""
    test_datasource = TestDatasource()
    return read_datasource(test_datasource)
33
+
34
+
35
def test_grid_search():
    """Datasets inside ``tune.grid_search`` are executed by ``execute_dataset``."""
    ds1 = gen_dataset_func().lazy().map(lambda x: x)
    ds2 = gen_dataset_func().lazy().map(lambda x: x)
    # Both datasets start out lazy (no materialized snapshot yet).
    assert not ds1._plan._has_final_stage_snapshot()
    assert not ds2._plan._has_final_stage_snapshot()
    param_space = {"train_dataset": tune.grid_search([ds1, ds2])}
    execute_dataset(param_space)
    executed_ds = param_space["train_dataset"]["grid_search"]
    assert len(executed_ds) == 2
    # After execution, both grid-search options hold a snapshot.
    assert executed_ds[0]._plan._has_final_stage_snapshot()
    assert executed_ds[1]._plan._has_final_stage_snapshot()
46
+
47
+
48
def test_choice():
    """Datasets inside ``tune.choice`` are executed by ``execute_dataset``."""
    ds1 = gen_dataset_func().lazy().map(lambda x: x)
    ds2 = gen_dataset_func().lazy().map(lambda x: x)
    # Both datasets start out lazy (no materialized snapshot yet).
    assert not ds1._plan._has_final_stage_snapshot()
    assert not ds2._plan._has_final_stage_snapshot()
    param_space = {"train_dataset": tune.choice([ds1, ds2])}
    execute_dataset(param_space)
    executed_ds = param_space["train_dataset"].categories
    assert len(executed_ds) == 2
    # After execution, both choice categories hold a snapshot.
    assert executed_ds[0]._plan._has_final_stage_snapshot()
    assert executed_ds[1]._plan._has_final_stage_snapshot()
59
+
60
+
61
if __name__ == "__main__":
    import sys

    import pytest

    # Run this module's tests when executed as a script; exit with
    # pytest's return code.
    sys.exit(pytest.main(["-v", "-x", __file__]))
.venv/lib/python3.11/site-packages/ray/tune/impl/tuner_internal.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import io
3
+ import logging
4
+ import math
5
+ from pathlib import Path
6
+ from typing import (
7
+ TYPE_CHECKING,
8
+ Any,
9
+ Callable,
10
+ Dict,
11
+ List,
12
+ Optional,
13
+ Tuple,
14
+ Type,
15
+ Union,
16
+ )
17
+
18
+ import pyarrow.fs
19
+
20
+ import ray.cloudpickle as pickle
21
+ import ray.train
22
+ from ray.air._internal.uri_utils import URI
23
+ from ray.air._internal.usage import AirEntrypoint
24
+ from ray.train import ScalingConfig
25
+ from ray.train._internal.storage import StorageContext, get_fs_and_path
26
+ from ray.train.constants import _v2_migration_warnings_enabled
27
+ from ray.train.utils import _log_deprecation_warning
28
+ from ray.tune import (
29
+ Experiment,
30
+ ExperimentAnalysis,
31
+ ResumeConfig,
32
+ RunConfig,
33
+ TuneConfig,
34
+ TuneError,
35
+ )
36
+ from ray.tune.registry import is_function_trainable
37
+ from ray.tune.result_grid import ResultGrid
38
+ from ray.tune.trainable import Trainable
39
+ from ray.tune.tune import _Config, run
40
+ from ray.tune.utils import flatten_dict
41
+ from ray.util import inspect_serializability
42
+
43
+ if TYPE_CHECKING:
44
+ from ray.train.trainer import BaseTrainer
45
+ from ray.util.queue import Queue
46
+
47
+
48
+ _TUNER_PKL = "tuner.pkl"
49
+ _TRAINABLE_KEY = "_trainable"
50
+ _CONVERTED_TRAINABLE_KEY = "_converted_trainable"
51
+ _PARAM_SPACE_KEY = "_param_space"
52
+ _EXPERIMENT_ANALYSIS_KEY = "_experiment_analysis"
53
+
54
+ logger = logging.getLogger(__name__)
55
+
56
+ TrainableType = Union[str, Callable, Type[Trainable]]
57
+ TrainableTypeOrTrainer = Union[TrainableType, "BaseTrainer"]
58
+
59
+
60
+ class TunerInternal:
61
+ """The real implementation behind external facing ``Tuner``.
62
+
63
+ The external facing ``Tuner`` multiplexes between local Tuner and remote Tuner
64
+ depending on whether in Ray client mode.
65
+
66
+ In Ray client mode, external ``Tuner`` wraps ``TunerInternal`` into a remote actor,
67
+ which is guaranteed to be placed on head node.
68
+
69
+ ``TunerInternal`` can be constructed from fresh, in which case, ``trainable`` needs
70
+ to be provided, together with optional ``param_space``, ``tune_config`` and
71
+ ``run_config``.
72
+
73
+ It can also be restored from a previous failed run (given ``restore_path``).
74
+
75
+ Args:
76
+ restore_path: The path from where the Tuner can be restored. If provided, None
77
+ of the rest args are needed.
78
+ resume_config: Resume config to configure which trials to continue.
79
+ trainable: The trainable to be tuned.
80
+ param_space: Search space of the tuning job.
81
+ One thing to note is that both preprocessor and dataset can be tuned here.
82
+ tune_config: Tuning algorithm specific configs.
83
+ Refer to ray.tune.tune_config.TuneConfig for more info.
84
+ run_config: Runtime configuration that is specific to individual trials.
85
+ If passed, this will overwrite the run config passed to the Trainer,
86
+ if applicable. Refer to ray.tune.RunConfig for more info.
87
+ """
88
+
89
def __init__(
    self,
    restore_path: str = None,
    storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
    resume_config: Optional[ResumeConfig] = None,
    trainable: Optional[TrainableTypeOrTrainer] = None,
    param_space: Optional[Dict[str, Any]] = None,
    tune_config: Optional[TuneConfig] = None,
    run_config: Optional[RunConfig] = None,
    _tuner_kwargs: Optional[Dict] = None,
    _entrypoint: AirEntrypoint = AirEntrypoint.TUNER,
):
    """Initialize the Tuner, either fresh or from a restore path.

    See the class docstring for argument semantics.
    """
    from ray.train.trainer import BaseTrainer

    if isinstance(trainable, BaseTrainer):
        if _v2_migration_warnings_enabled():
            _log_deprecation_warning(
                "Passing a Trainer to the Tuner is deprecated. "
                "See the section on hyperparameter optimization in this "
                "REP for more information: "
                "https://github.com/ray-project/enhancements/pull/57"
            )

        # A Trainer may carry its own RunConfig; pick between it and ours.
        run_config = self._choose_run_config(
            tuner_run_config=run_config,
            trainer=trainable,
            param_space=param_space,
        )

    self._tune_config = tune_config or TuneConfig()
    # Shallow-copy so later mutations (e.g. `name`) don't leak to the caller.
    self._run_config = copy.copy(run_config) or RunConfig()

    if not isinstance(self._run_config, RunConfig):
        if _v2_migration_warnings_enabled():
            _log_deprecation_warning(
                "The `RunConfig` class should be imported from `ray.tune` "
                "when passing it to the Tuner. Please update your imports."
            )

    self._entrypoint = _entrypoint

    # Restore from Tuner checkpoint.
    if restore_path:
        self._restore_from_path_or_uri(
            path_or_uri=restore_path,
            trainable=trainable,
            overwrite_param_space=param_space,
            resume_config=resume_config,
            storage_filesystem=storage_filesystem,
        )
        return

    # Start from fresh
    if not trainable:
        raise TuneError("You need to provide a trainable to tune.")

    self.trainable = trainable
    assert self.converted_trainable
    self._validate_trainable(self.converted_trainable)

    self.param_space = param_space

    self._resume_config = None
    self._is_restored = False
    self._tuner_kwargs = copy.deepcopy(_tuner_kwargs) or {}
    self._experiment_analysis = None

    # Fall back to an auto-generated experiment dir name.
    self._run_config.name = (
        self._run_config.name
        or StorageContext.get_experiment_dir_name(self.converted_trainable)
    )
    # The storage context here is only used to access the resolved
    # storage fs and experiment path, in order to avoid duplicating that logic.
    # This is NOT the storage context object that gets passed to remote workers.
    storage = StorageContext(
        storage_path=self._run_config.storage_path,
        experiment_dir_name=self._run_config.name,
        storage_filesystem=self._run_config.storage_filesystem,
    )

    # Persist the Tuner state up-front so the run can later be restored.
    fs = storage.storage_filesystem
    fs.create_dir(storage.experiment_fs_path)
    with fs.open_output_stream(
        Path(storage.experiment_fs_path, _TUNER_PKL).as_posix()
    ) as f:
        f.write(pickle.dumps(self.__getstate__()))
175
+
176
def get_run_config(self) -> RunConfig:
    """Return the resolved run config for this tuning job."""
    return self._run_config
178
+
179
# For Jupyter output with Ray Client
def set_run_config_and_remote_string_queue(
    self, run_config: RunConfig, string_queue: "Queue"
):
    """Install a run config and a remote queue for streaming output."""
    self._run_config = run_config
    self._tuner_kwargs["_remote_string_queue"] = string_queue
185
+
186
def clear_remote_string_queue(self):
    """Drop the remote output queue, if one was installed."""
    self._tuner_kwargs.pop("_remote_string_queue", None)
188
+
189
def _expected_utilization(self, cpus_per_trial, cpus_total):
    """Estimate the fraction of cluster CPUs this run will occupy."""
    num_samples = self._tune_config.num_samples
    if num_samples < 0:  # TODO: simplify this in Tune
        # Negative num_samples means "unlimited samples".
        num_samples = math.inf
    concurrent_trials = self._tune_config.max_concurrent_trials or 0
    if concurrent_trials < 1:  # TODO: simplify this in Tune
        # Unset / zero means "no concurrency limit".
        concurrent_trials = math.inf

    # Concurrency is limited by available CPUs, sample count, and the
    # configured max concurrency — whichever is smallest.
    actual_concurrency = min(
        (
            (cpus_total // cpus_per_trial) if cpus_per_trial else 0,
            num_samples,
            concurrent_trials,
        )
    )
    # Small epsilon avoids division by zero when cpus_total == 0.
    return (actual_concurrency * cpus_per_trial) / (cpus_total + 0.001)
205
+
206
def _validate_trainable(
    self, trainable: TrainableType, required_trainable_name: Optional[str] = None
):
    """Determines whether or not the trainable is valid.

    This includes checks on the serializability of the trainable, as well
    asserting that the trainable name is as expected on restoration.

    This trainable name validation is needed due to an implementation detail
    where the trainable name (which is differently generated depending on
    the trainable type) is saved in the Trial metadata and needs to match
    upon restoration. This does not affect the typical path, since `Tuner.restore`
    expects the exact same trainable (which will have the same name).

    Raises:
        ValueError: if the trainable name does not match or if the trainable
            is not serializable.
    """
    # The trainable is shipped to remote workers, so it must be picklable.
    try:
        pickle.dumps(trainable)
    except TypeError as e:
        # Collect a human-readable trace of the unserializable objects.
        sio = io.StringIO()
        inspect_serializability(trainable, print_file=sio)
        msg = (
            "The provided trainable is not serializable, which is a requirement "
            "since the trainable is serialized and deserialized when transferred "
            "to remote workers. See below for a trace of the non-serializable "
            "objects that were found in your trainable:\n"
            f"{sio.getvalue()}"
        )
        raise TypeError(msg) from e

    if not required_trainable_name:
        return

    trainable_name = Experiment.get_trainable_name(trainable)

    if trainable_name != required_trainable_name:
        raise ValueError(
            "Invalid `trainable` input to `Tuner.restore()`. To fix this error, "
            "pass in the same trainable that was used to initialize the Tuner. "
            "Got a trainable with identifier "
            f"'{trainable_name}' but expected '{required_trainable_name}'."
        )
250
+
251
def _set_trainable_on_restore(
    self, trainable: TrainableType, old_trainable_name: Optional[str]
):
    """Set and validate the re-specified trainable on ``Tuner.restore``."""
    from ray.train.base_trainer import BaseTrainer

    self.trainable = trainable
    assert self.converted_trainable
    self._validate_trainable(
        trainable=self.converted_trainable,
        required_trainable_name=old_trainable_name,
    )

    if isinstance(self.trainable, BaseTrainer):
        # Log a warning in case the user tries to modify the
        # `RunConfig` from the Trainer
        trainer: BaseTrainer = self.trainable

        # Only log if the Trainer has a non-default RunConfig
        if trainer.run_config != RunConfig():
            logger.warning(
                "The Tune experiment will restore using the original run's "
                "`RunConfig`. If you made any changes to the `RunConfig` "
                "within the Trainer you passed into `Tuner.restore`, "
                "they will be ignored in the resumed run."
            )

        # Force the Trainer to use the restored run's config.
        trainer.run_config = self._run_config
278
+
279
def _validate_param_space_on_restore(
    self,
    new_param_space: Dict[str, Any],
    flattened_param_space_keys: Optional[List[str]],
):
    """Determines whether the (optionally) re-specified `param_space` is valid.

    This method performs very loose validation on the new param_space to
    prevent users from trying to specify new hyperparameters to tune over.

    Raises:
        ValueError: if not all keys match the original param_space.
    """
    if flattened_param_space_keys is None:
        # Backwards compatibility: skip validation
        return

    # Only the flattened key sets are compared; values may legitimately
    # differ (e.g. refreshed Ray object references).
    keys = sorted(flatten_dict(new_param_space).keys())
    if keys != flattened_param_space_keys:
        raise ValueError(
            "Invalid `param_space` input to `Tuner.restore()`. To fix this error, "
            "pass in the same `param_space` that was used to initialize the Tuner. "
            "Only re-specify the `param_space` to refresh Ray object references "
            "that no longer exist due to restoring from a new Ray cluster session. "
            "It should not be used to introduce new hyperparameters to tune."
            f"\n\nGot: {keys}\nExpected: {flattened_param_space_keys}"
        )
306
+
307
def _set_param_space_on_restore(
    self,
    param_space: Optional[Dict[str, Any]],
    flattened_param_space_keys: Optional[List[str]],
):
    """Set and (loosely) validate the re-specified ``param_space``."""
    self.param_space = param_space

    if self.param_space is not None:
        # param_space = None -> use the original param_space
        self._validate_param_space_on_restore(
            new_param_space=self.param_space,
            flattened_param_space_keys=flattened_param_space_keys,
        )
320
+
321
def _load_tuner_state(
    self, tuner_state: Dict[str, Any]
) -> Tuple[Optional[str], Optional[List[str]]]:
    """Loads Tuner state from the previously saved `tuner.pkl`.

    Args:
        tuner_state: Unpickled state dict from the `tuner.pkl` file saved
            during the original Tuner initialization.

    Returns:
        tuple: of `(old_trainable_name, flattened_param_space_keys)` used for
            validating the re-specified `trainable` and `param_space`.
    """
    # NOTE: These are magic keys used for validating restore args.
    # Pop them so they don't leak into `__setstate__`.
    old_trainable_name = tuner_state.pop("__trainable_name", None)
    flattened_param_space_keys = tuner_state.pop(
        "__flattened_param_space_keys", None
    )

    self.__setstate__(tuner_state)

    return old_trainable_name, flattened_param_space_keys
343
+
344
def _restore_from_path_or_uri(
    self,
    path_or_uri: str,
    trainable: TrainableTypeOrTrainer,
    overwrite_param_space: Optional[Dict[str, Any]],
    resume_config: ResumeConfig,
    storage_filesystem: Optional[pyarrow.fs.FileSystem],
):
    """Restore Tuner state from a previous run's directory or URI."""
    fs, fs_path = get_fs_and_path(path_or_uri, storage_filesystem)
    with fs.open_input_file(Path(fs_path, _TUNER_PKL).as_posix()) as f:
        tuner_state = pickle.loads(f.readall())

    old_trainable_name, flattened_param_space_keys = self._load_tuner_state(
        tuner_state
    )

    # Perform validation and set the re-specified `trainable` and `param_space`
    self._set_trainable_on_restore(
        trainable=trainable, old_trainable_name=old_trainable_name
    )
    self._set_param_space_on_restore(
        param_space=overwrite_param_space,
        flattened_param_space_keys=flattened_param_space_keys,
    )

    # Update RunConfig to reflect changes in the experiment directory
    path_or_uri_obj = URI(path_or_uri)

    # Infer the `storage_path` and run `name` of the restored run using the
    # experiment directory.
    # Ex: ~/ray_results/exp_name -> ~/ray_results, exp_name
    # Ex: s3://bucket/exp_name -> s3://bucket, exp_name
    self._run_config.name = path_or_uri_obj.name
    self._run_config.storage_path = str(path_or_uri_obj.parent)
    # Update the storage_filesystem with the one passed in on restoration, if any.
    self._run_config.storage_filesystem = storage_filesystem

    # Load the experiment results at the point where it left off.
    try:
        self._experiment_analysis = ExperimentAnalysis(
            experiment_checkpoint_path=path_or_uri,
            default_metric=self._tune_config.metric,
            default_mode=self._tune_config.mode,
            storage_filesystem=storage_filesystem,
        )
    except Exception:
        # Best effort: analysis of previous results is optional on restore.
        self._experiment_analysis = None

    self._resume_config = resume_config
    self._is_restored = True
394
+
395
def _choose_run_config(
    self,
    tuner_run_config: Optional[RunConfig],
    trainer: "BaseTrainer",
    param_space: Optional[Dict[str, Any]],
) -> RunConfig:
    """Chooses which `RunConfig` to use when multiple can be passed in
    through a Trainer or the Tuner itself.

    Args:
        tuner_run_config: The run config passed into the Tuner constructor.
        trainer: The Trainer instance to use with Tune, which may have
            a RunConfig specified by the user.
        param_space: The param space passed to the Tuner.

    Raises:
        ValueError: if the `run_config` is specified as a hyperparameter.
    """
    if param_space and "run_config" in param_space:
        raise ValueError(
            "`RunConfig` cannot be tuned as part of the `param_space`! "
            "Move the run config to be a parameter of the `Tuner`: "
            "Tuner(..., run_config=RunConfig(...))"
        )

    # Both Tuner RunConfig + Trainer RunConfig --> prefer Tuner RunConfig
    if tuner_run_config and trainer.run_config != ray.train.RunConfig():
        logger.info(
            "A `RunConfig` was passed to both the `Tuner` and the "
            f"`{trainer.__class__.__name__}`. The run config passed to "
            "the `Tuner` is the one that will be used."
        )
        return tuner_run_config

    # No Tuner RunConfig -> pass the Trainer config through
    # This returns either a user-specified config, or the default RunConfig
    # if nothing was provided to both the Trainer or Tuner.
    if not tuner_run_config:
        return trainer.run_config

    # Tuner RunConfig + No Trainer RunConfig --> Use the Tuner config
    return tuner_run_config
437
+
438
+ def _process_scaling_config(self) -> None:
439
+ """Converts ``self._param_space["scaling_config"]`` to a dict.
440
+
441
+ The dict is converted back to a dataclass by the Trainer, after the
442
+ Tune search specification is resolved.
443
+ """
444
+ # TODO: introduce `ray.tune.sample.TuneableDataclass` and allow Tune to
445
+ # natively resolve specs with dataclasses.
446
+ scaling_config = self._param_space.get("scaling_config")
447
+ if not isinstance(scaling_config, ScalingConfig):
448
+ return
449
+ self._param_space["scaling_config"] = scaling_config.__dict__.copy()
450
+
451
+ @property
452
+ def trainable(self) -> TrainableTypeOrTrainer:
453
+ return self._trainable
454
+
455
+ @property
456
+ def converted_trainable(self) -> TrainableType:
457
+ return self._converted_trainable
458
+
459
+ @trainable.setter
460
+ def trainable(self, trainable: TrainableTypeOrTrainer):
461
+ self._trainable = trainable
462
+ self._converted_trainable = self._convert_trainable(trainable)
463
+
464
+ @property
465
+ def param_space(self) -> Optional[Dict[str, Any]]:
466
+ return self._param_space
467
+
468
+ @param_space.setter
469
+ def param_space(self, param_space: Optional[Dict[str, Any]]):
470
+ # Handle any configs that adhere to the `to_dict` interface.
471
+ # Ex: AlgorithmConfig from RLlib
472
+ if isinstance(param_space, _Config):
473
+ param_space = param_space.to_dict()
474
+
475
+ if not isinstance(param_space, dict) and param_space is not None:
476
+ raise ValueError(
477
+ "The `param_space` passed to the `Tuner` must be a dict. "
478
+ f"Got '{type(param_space)}' instead."
479
+ )
480
+
481
+ self._param_space = param_space
482
+
483
+ if param_space:
484
+ self._process_scaling_config()
485
+
486
+ def _convert_trainable(self, trainable: TrainableTypeOrTrainer) -> TrainableType:
487
+ """Converts a Trainer to a Tune trainable and saves the converted
488
+ trainable. If not using a Trainer, this leaves the trainable as is."""
489
+ from ray.train.trainer import BaseTrainer
490
+
491
+ return (
492
+ trainable.as_trainable()
493
+ if isinstance(trainable, BaseTrainer)
494
+ else trainable
495
+ )
496
+
497
+ def fit(self) -> ResultGrid:
498
+ trainable = self.converted_trainable
499
+ param_space = copy.deepcopy(self.param_space)
500
+ if not self._is_restored:
501
+ analysis = self._fit_internal(trainable, param_space)
502
+ else:
503
+ analysis = self._fit_resume(trainable, param_space)
504
+
505
+ self._experiment_analysis = analysis
506
+
507
+ return ResultGrid(self._experiment_analysis)
508
+
509
+ def get_results(self) -> ResultGrid:
510
+ if not self._experiment_analysis:
511
+ raise RuntimeError(
512
+ "Can't return results as experiment has not been run, yet. "
513
+ "Call `Tuner.fit()` to run the experiment first."
514
+ )
515
+ return ResultGrid(self._experiment_analysis)
516
+
517
+ def _get_tune_run_arguments(self, trainable: TrainableType) -> Dict[str, Any]:
518
+ """Get tune.run arguments common for both new and resumed runs."""
519
+ # Avoid overwriting the originally configured checkpoint config.
520
+ checkpoint_config = copy.deepcopy(self._run_config.checkpoint_config)
521
+
522
+ if checkpoint_config.checkpoint_frequency:
523
+ # Function trainables (and thus most of our trainers) usually don't handle
524
+ # this argument.
525
+ handle_checkpoint_freq = getattr(
526
+ trainable, "_handles_checkpoint_freq", None
527
+ )
528
+ if handle_checkpoint_freq is False:
529
+ # If we specifically know this trainable doesn't support the
530
+ # argument, raise an error
531
+ raise ValueError(
532
+ "You passed `checkpoint_frequency="
533
+ f"{checkpoint_config.checkpoint_frequency}` to your "
534
+ "CheckpointConfig, but this trainer does not support "
535
+ "this argument. If you passed in a Trainer that takes in a "
536
+ "custom training loop, you will need to "
537
+ "report a checkpoint every `checkpoint_frequency` iterations "
538
+ "within your training loop using "
539
+ "`ray.train.report(metrics=..., checkpoint=...)` "
540
+ "to get this behavior."
541
+ )
542
+ elif handle_checkpoint_freq is True:
543
+ # If we specifically support it, it's handled in the training loop,
544
+ # so we disable tune's bookkeeping.
545
+ checkpoint_config.checkpoint_frequency = 0
546
+ # Otherwise, the trainable is not a Trainer and we just keep the
547
+ # user-supplied value.
548
+ # Function trainables will raise a runtime error later if set > 0
549
+ if checkpoint_config.checkpoint_at_end is not None:
550
+ # Again, function trainables usually don't handle this argument.
551
+ handle_cp_at_end = getattr(trainable, "_handles_checkpoint_at_end", None)
552
+ if handle_cp_at_end is False:
553
+ # If we specifically know we don't support it, raise an error.
554
+ raise ValueError(
555
+ "You passed `checkpoint_at_end="
556
+ f"{checkpoint_config.checkpoint_at_end}` "
557
+ "to your CheckpointConfig, but this trainer does not support "
558
+ "this argument. If you passed in a Trainer that takes in a "
559
+ "custom training loop, you should include one last call to "
560
+ "`ray.train.report(metrics=..., checkpoint=...)` "
561
+ "at the end of your training loop to get this behavior."
562
+ )
563
+ elif handle_cp_at_end is True:
564
+ # If we specifically support it, it's handled in the training loop,
565
+ # so we disable tune's internal bookkeeping.
566
+ checkpoint_config.checkpoint_at_end = False
567
+ # If this is a user-defined trainable, just keep the value
568
+ # Function trainables will raise a runtime error later if set to True
569
+ else:
570
+ # Set default to False for function trainables and True for everything else
571
+ if is_function_trainable(trainable):
572
+ checkpoint_config.checkpoint_at_end = False
573
+ else:
574
+ checkpoint_config.checkpoint_at_end = True
575
+
576
+ return dict(
577
+ storage_path=self._run_config.storage_path,
578
+ storage_filesystem=self._run_config.storage_filesystem,
579
+ name=self._run_config.name,
580
+ mode=self._tune_config.mode,
581
+ metric=self._tune_config.metric,
582
+ callbacks=self._run_config.callbacks,
583
+ sync_config=self._run_config.sync_config,
584
+ stop=self._run_config.stop,
585
+ max_failures=self._run_config.failure_config.max_failures,
586
+ checkpoint_config=checkpoint_config,
587
+ raise_on_failed_trial=False,
588
+ fail_fast=(self._run_config.failure_config.fail_fast),
589
+ progress_reporter=self._run_config.progress_reporter,
590
+ verbose=self._run_config.verbose,
591
+ reuse_actors=self._tune_config.reuse_actors,
592
+ max_concurrent_trials=self._tune_config.max_concurrent_trials,
593
+ time_budget_s=self._tune_config.time_budget_s,
594
+ trial_name_creator=self._tune_config.trial_name_creator,
595
+ trial_dirname_creator=self._tune_config.trial_dirname_creator,
596
+ _entrypoint=self._entrypoint,
597
+ # Deprecated
598
+ chdir_to_trial_dir=self._tune_config.chdir_to_trial_dir,
599
+ )
600
+
601
+ def _fit_internal(
602
+ self, trainable: TrainableType, param_space: Optional[Dict[str, Any]]
603
+ ) -> ExperimentAnalysis:
604
+ """Fitting for a fresh Tuner."""
605
+ args = {
606
+ **self._get_tune_run_arguments(trainable),
607
+ **dict(
608
+ run_or_experiment=trainable,
609
+ config=param_space,
610
+ num_samples=self._tune_config.num_samples,
611
+ search_alg=self._tune_config.search_alg,
612
+ scheduler=self._tune_config.scheduler,
613
+ log_to_file=self._run_config.log_to_file,
614
+ ),
615
+ **self._tuner_kwargs,
616
+ }
617
+ analysis = run(
618
+ **args,
619
+ )
620
+ self.clear_remote_string_queue()
621
+ return analysis
622
+
623
+ def _fit_resume(
624
+ self, trainable: TrainableType, param_space: Optional[Dict[str, Any]]
625
+ ) -> ExperimentAnalysis:
626
+ """Fitting for a restored Tuner."""
627
+ assert self._resume_config
628
+
629
+ args = {
630
+ **self._get_tune_run_arguments(trainable),
631
+ **dict(
632
+ run_or_experiment=trainable,
633
+ config=param_space,
634
+ resume_config=self._resume_config,
635
+ search_alg=self._tune_config.search_alg,
636
+ scheduler=self._tune_config.scheduler,
637
+ ),
638
+ **self._tuner_kwargs,
639
+ }
640
+ analysis = run(**args)
641
+ self.clear_remote_string_queue()
642
+ return analysis
643
+
644
+ def __getstate__(self):
645
+ state = self.__dict__.copy()
646
+ state["_tuner_kwargs"] = state["_tuner_kwargs"].copy()
647
+ state["_tuner_kwargs"].pop("_remote_string_queue", None)
648
+ state.pop(_TRAINABLE_KEY, None)
649
+ trainable = state.pop(_CONVERTED_TRAINABLE_KEY, None)
650
+ param_space = state.pop(_PARAM_SPACE_KEY, None)
651
+ state.pop(_EXPERIMENT_ANALYSIS_KEY, None)
652
+
653
+ state["__trainable_name"] = (
654
+ Experiment.get_trainable_name(trainable) if trainable else None
655
+ )
656
+ state["__flattened_param_space_keys"] = (
657
+ sorted(flatten_dict(param_space).keys())
658
+ if param_space is not None
659
+ else None
660
+ )
661
+
662
+ return state
663
+
664
+ def __setstate__(self, state):
665
+ # Make sure the magic metadata gets removed first.
666
+ state.pop("__flattened_param_space_keys", None)
667
+ state.pop("__trainable_name", None)
668
+
669
+ self.__dict__.update(state)
.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.35 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/resource_changing_scheduler.cpython-311.pyc ADDED
Binary file (41 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/search/__init__.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray._private.utils import get_function_args
2
+ from ray.tune.search.basic_variant import BasicVariantGenerator
3
+ from ray.tune.search.concurrency_limiter import ConcurrencyLimiter
4
+ from ray.tune.search.repeater import Repeater
5
+ from ray.tune.search.search_algorithm import SearchAlgorithm
6
+ from ray.tune.search.search_generator import SearchGenerator
7
+ from ray.tune.search.searcher import Searcher
8
+ from ray.tune.search.variant_generator import grid_search
9
+ from ray.util import PublicAPI
10
+
11
+
12
+ def _import_variant_generator():
13
+ return BasicVariantGenerator
14
+
15
+
16
+ def _import_ax_search():
17
+ from ray.tune.search.ax.ax_search import AxSearch
18
+
19
+ return AxSearch
20
+
21
+
22
+ def _import_hyperopt_search():
23
+ from ray.tune.search.hyperopt.hyperopt_search import HyperOptSearch
24
+
25
+ return HyperOptSearch
26
+
27
+
28
+ def _import_bayesopt_search():
29
+ from ray.tune.search.bayesopt.bayesopt_search import BayesOptSearch
30
+
31
+ return BayesOptSearch
32
+
33
+
34
+ def _import_bohb_search():
35
+ from ray.tune.search.bohb.bohb_search import TuneBOHB
36
+
37
+ return TuneBOHB
38
+
39
+
40
+ def _import_nevergrad_search():
41
+ from ray.tune.search.nevergrad.nevergrad_search import NevergradSearch
42
+
43
+ return NevergradSearch
44
+
45
+
46
+ def _import_optuna_search():
47
+ from ray.tune.search.optuna.optuna_search import OptunaSearch
48
+
49
+ return OptunaSearch
50
+
51
+
52
+ def _import_zoopt_search():
53
+ from ray.tune.search.zoopt.zoopt_search import ZOOptSearch
54
+
55
+ return ZOOptSearch
56
+
57
+
58
+ def _import_hebo_search():
59
+ from ray.tune.search.hebo.hebo_search import HEBOSearch
60
+
61
+ return HEBOSearch
62
+
63
+
64
+ SEARCH_ALG_IMPORT = {
65
+ "variant_generator": _import_variant_generator,
66
+ "random": _import_variant_generator,
67
+ "ax": _import_ax_search,
68
+ "hyperopt": _import_hyperopt_search,
69
+ "bayesopt": _import_bayesopt_search,
70
+ "bohb": _import_bohb_search,
71
+ "nevergrad": _import_nevergrad_search,
72
+ "optuna": _import_optuna_search,
73
+ "zoopt": _import_zoopt_search,
74
+ "hebo": _import_hebo_search,
75
+ }
76
+
77
+
78
+ @PublicAPI(stability="beta")
79
+ def create_searcher(
80
+ search_alg,
81
+ **kwargs,
82
+ ):
83
+ """Instantiate a search algorithm based on the given string.
84
+
85
+ This is useful for swapping between different search algorithms.
86
+
87
+ Args:
88
+ search_alg: The search algorithm to use.
89
+ metric: The training result objective value attribute. Stopping
90
+ procedures will use this attribute.
91
+ mode: One of {min, max}. Determines whether objective is
92
+ minimizing or maximizing the metric attribute.
93
+ **kwargs: Additional parameters.
94
+ These keyword arguments will be passed to the initialization
95
+ function of the chosen class.
96
+ Returns:
97
+ ray.tune.search.Searcher: The search algorithm.
98
+ Example:
99
+ >>> from ray import tune # doctest: +SKIP
100
+ >>> search_alg = tune.create_searcher('ax') # doctest: +SKIP
101
+ """
102
+
103
+ search_alg = search_alg.lower()
104
+ if search_alg not in SEARCH_ALG_IMPORT:
105
+ raise ValueError(
106
+ f"The `search_alg` argument must be one of "
107
+ f"{list(SEARCH_ALG_IMPORT)}. "
108
+ f"Got: {search_alg}"
109
+ )
110
+
111
+ SearcherClass = SEARCH_ALG_IMPORT[search_alg]()
112
+
113
+ search_alg_args = get_function_args(SearcherClass)
114
+ trimmed_kwargs = {k: v for k, v in kwargs.items() if k in search_alg_args}
115
+
116
+ return SearcherClass(**trimmed_kwargs)
117
+
118
+
119
+ UNRESOLVED_SEARCH_SPACE = str(
120
+ "You passed a `{par}` parameter to {cls} that contained unresolved search "
121
+ "space definitions. {cls} should however be instantiated with fully "
122
+ "configured search spaces only. To use Ray Tune's automatic search space "
123
+ "conversion, pass the space definition as part of the `param_space` argument "
124
+ "to `tune.Tuner()` instead."
125
+ )
126
+
127
+ UNDEFINED_SEARCH_SPACE = str(
128
+ "Trying to sample a configuration from {cls}, but no search "
129
+ "space has been defined. Either pass the `{space}` argument when "
130
+ "instantiating the search algorithm, or pass a `param_space` to "
131
+ "`tune.Tuner()`."
132
+ )
133
+
134
+ UNDEFINED_METRIC_MODE = str(
135
+ "Trying to sample a configuration from {cls}, but the `metric` "
136
+ "({metric}) or `mode` ({mode}) parameters have not been set. "
137
+ "Either pass these arguments when instantiating the search algorithm, "
138
+ "or pass them to `tune.TuneConfig()`."
139
+ )
140
+
141
+
142
+ __all__ = [
143
+ "SearchAlgorithm",
144
+ "Searcher",
145
+ "ConcurrencyLimiter",
146
+ "Repeater",
147
+ "BasicVariantGenerator",
148
+ "grid_search",
149
+ "SearchGenerator",
150
+ "UNRESOLVED_SEARCH_SPACE",
151
+ "UNDEFINED_SEARCH_SPACE",
152
+ "UNDEFINED_METRIC_MODE",
153
+ ]
.venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/concurrency_limiter.cpython-311.pyc ADDED
Binary file (9.22 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/sample.cpython-311.pyc ADDED
Binary file (40.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/search_algorithm.cpython-311.pyc ADDED
Binary file (5.87 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/search_generator.cpython-311.pyc ADDED
Binary file (12.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/search/_mock.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional
2
+
3
+ from ray.tune.experiment import Trial
4
+ from ray.tune.search import ConcurrencyLimiter, Searcher
5
+ from ray.tune.search.search_generator import SearchGenerator
6
+
7
+
8
+ class _MockSearcher(Searcher):
9
+ def __init__(self, **kwargs):
10
+ self.live_trials = {}
11
+ self.counter = {"result": 0, "complete": 0}
12
+ self.final_results = []
13
+ self.stall = False
14
+ self.results = []
15
+ super(_MockSearcher, self).__init__(**kwargs)
16
+
17
+ def suggest(self, trial_id: str):
18
+ if not self.stall:
19
+ self.live_trials[trial_id] = 1
20
+ return {"test_variable": 2}
21
+ return None
22
+
23
+ def on_trial_result(self, trial_id: str, result: Dict):
24
+ self.counter["result"] += 1
25
+ self.results += [result]
26
+
27
+ def on_trial_complete(
28
+ self, trial_id: str, result: Optional[Dict] = None, error: bool = False
29
+ ):
30
+ self.counter["complete"] += 1
31
+ if result:
32
+ self._process_result(result)
33
+ if trial_id in self.live_trials:
34
+ del self.live_trials[trial_id]
35
+
36
+ def _process_result(self, result: Dict):
37
+ self.final_results += [result]
38
+
39
+
40
+ class _MockSuggestionAlgorithm(SearchGenerator):
41
+ def __init__(self, max_concurrent: Optional[int] = None, **kwargs):
42
+ self.searcher = _MockSearcher(**kwargs)
43
+ if max_concurrent:
44
+ self.searcher = ConcurrencyLimiter(
45
+ self.searcher, max_concurrent=max_concurrent
46
+ )
47
+ super(_MockSuggestionAlgorithm, self).__init__(self.searcher)
48
+
49
+ @property
50
+ def live_trials(self) -> List[Trial]:
51
+ return self.searcher.live_trials
52
+
53
+ @property
54
+ def results(self) -> List[Dict]:
55
+ return self.searcher.results
.venv/lib/python3.11/site-packages/ray/tune/search/ax/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (291 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/search/ax/__pycache__/ax_search.cpython-311.pyc ADDED
Binary file (19.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/search/basic_variant.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import itertools
3
+ import os
4
+ import uuid
5
+ import warnings
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
8
+
9
+ import numpy as np
10
+
11
+ from ray.air._internal.usage import tag_searcher
12
+ from ray.tune.error import TuneError
13
+ from ray.tune.experiment.config_parser import _create_trial_from_spec, _make_parser
14
+ from ray.tune.search.sample import _BackwardsCompatibleNumpyRng, np_random_generator
15
+ from ray.tune.search.search_algorithm import SearchAlgorithm
16
+ from ray.tune.search.variant_generator import (
17
+ _count_spec_samples,
18
+ _count_variants,
19
+ _flatten_resolved_vars,
20
+ _get_preset_variants,
21
+ format_vars,
22
+ generate_variants,
23
+ )
24
+ from ray.tune.utils.util import _atomic_save, _load_newest_checkpoint
25
+ from ray.util import PublicAPI
26
+
27
+ if TYPE_CHECKING:
28
+ from ray.tune.experiment import Experiment
29
+
30
+ SERIALIZATION_THRESHOLD = 1e6
31
+
32
+
33
+ class _VariantIterator:
34
+ """Iterates over generated variants from the search space.
35
+
36
+ This object also toggles between lazy evaluation and
37
+ eager evaluation of samples. If lazy evaluation is enabled,
38
+ this object cannot be serialized.
39
+ """
40
+
41
+ def __init__(self, iterable, lazy_eval=False):
42
+ self.lazy_eval = lazy_eval
43
+ self.iterable = iterable
44
+ self._has_next = True
45
+ if lazy_eval:
46
+ self._load_value()
47
+ else:
48
+ self.iterable = list(iterable)
49
+ self._has_next = bool(self.iterable)
50
+
51
+ def _load_value(self):
52
+ try:
53
+ self.next_value = next(self.iterable)
54
+ except StopIteration:
55
+ self._has_next = False
56
+
57
+ def has_next(self):
58
+ return self._has_next
59
+
60
+ def __next__(self):
61
+ if self.lazy_eval:
62
+ current_value = self.next_value
63
+ self._load_value()
64
+ return current_value
65
+ current_value = self.iterable.pop(0)
66
+ self._has_next = bool(self.iterable)
67
+ return current_value
68
+
69
+
70
+ class _TrialIterator:
71
+ """Generates trials from the spec.
72
+
73
+ Args:
74
+ uuid_prefix: Used in creating the trial name.
75
+ num_samples: Number of samples from distribution
76
+ (same as tune.TuneConfig).
77
+ unresolved_spec: Experiment specification
78
+ that might have unresolved distributions.
79
+ constant_grid_search: Should random variables be sampled
80
+ first before iterating over grid variants (True) or not (False).
81
+ points_to_evaluate: Configurations that will be tried out without sampling.
82
+ lazy_eval: Whether variants should be generated
83
+ lazily or eagerly. This is toggled depending
84
+ on the size of the grid search.
85
+ start: index at which to start counting trials.
86
+ random_state (int | np.random.Generator | np.random.RandomState):
87
+ Seed or numpy random generator to use for reproducible results.
88
+ If None (default), will use the global numpy random generator
89
+ (``np.random``). Please note that full reproducibility cannot
90
+ be guaranteed in a distributed enviroment.
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ uuid_prefix: str,
96
+ num_samples: int,
97
+ unresolved_spec: dict,
98
+ constant_grid_search: bool = False,
99
+ points_to_evaluate: Optional[List] = None,
100
+ lazy_eval: bool = False,
101
+ start: int = 0,
102
+ random_state: Optional[
103
+ Union[int, "np_random_generator", np.random.RandomState]
104
+ ] = None,
105
+ ):
106
+ self.parser = _make_parser()
107
+ self.num_samples = num_samples
108
+ self.uuid_prefix = uuid_prefix
109
+ self.num_samples_left = num_samples
110
+ self.unresolved_spec = unresolved_spec
111
+ self.constant_grid_search = constant_grid_search
112
+ self.points_to_evaluate = points_to_evaluate or []
113
+ self.num_points_to_evaluate = len(self.points_to_evaluate)
114
+ self.counter = start
115
+ self.lazy_eval = lazy_eval
116
+ self.variants = None
117
+ self.random_state = random_state
118
+
119
+ def create_trial(self, resolved_vars, spec):
120
+ trial_id = self.uuid_prefix + ("%05d" % self.counter)
121
+ experiment_tag = str(self.counter)
122
+ # Always append resolved vars to experiment tag?
123
+ if resolved_vars:
124
+ experiment_tag += "_{}".format(format_vars(resolved_vars))
125
+ self.counter += 1
126
+ return _create_trial_from_spec(
127
+ spec,
128
+ self.parser,
129
+ evaluated_params=_flatten_resolved_vars(resolved_vars),
130
+ trial_id=trial_id,
131
+ experiment_tag=experiment_tag,
132
+ )
133
+
134
+ def __next__(self):
135
+ """Generates Trial objects with the variant generation process.
136
+
137
+ Uses a fixed point iteration to resolve variants. All trials
138
+ should be able to be generated at once.
139
+
140
+ See also: `ray.tune.search.variant_generator`.
141
+
142
+ Returns:
143
+ Trial object
144
+ """
145
+
146
+ if "run" not in self.unresolved_spec:
147
+ raise TuneError("Must specify `run` in {}".format(self.unresolved_spec))
148
+
149
+ if self.variants and self.variants.has_next():
150
+ # This block will be skipped upon instantiation.
151
+ # `variants` will be set later after the first loop.
152
+ resolved_vars, spec = next(self.variants)
153
+ return self.create_trial(resolved_vars, spec)
154
+
155
+ if self.points_to_evaluate:
156
+ config = self.points_to_evaluate.pop(0)
157
+ self.num_samples_left -= 1
158
+ self.variants = _VariantIterator(
159
+ _get_preset_variants(
160
+ self.unresolved_spec,
161
+ config,
162
+ constant_grid_search=self.constant_grid_search,
163
+ random_state=self.random_state,
164
+ ),
165
+ lazy_eval=self.lazy_eval,
166
+ )
167
+ resolved_vars, spec = next(self.variants)
168
+ return self.create_trial(resolved_vars, spec)
169
+ elif self.num_samples_left > 0:
170
+ self.variants = _VariantIterator(
171
+ generate_variants(
172
+ self.unresolved_spec,
173
+ constant_grid_search=self.constant_grid_search,
174
+ random_state=self.random_state,
175
+ ),
176
+ lazy_eval=self.lazy_eval,
177
+ )
178
+ self.num_samples_left -= 1
179
+ resolved_vars, spec = next(self.variants)
180
+ return self.create_trial(resolved_vars, spec)
181
+ else:
182
+ raise StopIteration
183
+
184
+ def __iter__(self):
185
+ return self
186
+
187
+
188
+ @PublicAPI
189
+ class BasicVariantGenerator(SearchAlgorithm):
190
+ """Uses Tune's variant generation for resolving variables.
191
+
192
+ This is the default search algorithm used if no other search algorithm
193
+ is specified.
194
+
195
+
196
+ Args:
197
+ points_to_evaluate: Initial parameter suggestions to be run
198
+ first. This is for when you already have some good parameters
199
+ you want to run first to help the algorithm make better suggestions
200
+ for future parameters. Needs to be a list of dicts containing the
201
+ configurations.
202
+ max_concurrent: Maximum number of concurrently running trials.
203
+ If 0 (default), no maximum is enforced.
204
+ constant_grid_search: If this is set to ``True``, Ray Tune will
205
+ *first* try to sample random values and keep them constant over
206
+ grid search parameters. If this is set to ``False`` (default),
207
+ Ray Tune will sample new random parameters in each grid search
208
+ condition.
209
+ random_state:
210
+ Seed or numpy random generator to use for reproducible results.
211
+ If None (default), will use the global numpy random generator
212
+ (``np.random``). Please note that full reproducibility cannot
213
+ be guaranteed in a distributed environment.
214
+
215
+
216
+ Example:
217
+
218
+ .. code-block:: python
219
+
220
+ from ray import tune
221
+
222
+ # This will automatically use the `BasicVariantGenerator`
223
+ tuner = tune.Tuner(
224
+ lambda config: config["a"] + config["b"],
225
+ tune_config=tune.TuneConfig(
226
+ num_samples=4
227
+ ),
228
+ param_space={
229
+ "a": tune.grid_search([1, 2]),
230
+ "b": tune.randint(0, 3)
231
+ },
232
+ )
233
+ tuner.fit()
234
+
235
+ In the example above, 8 trials will be generated: For each sample
236
+ (``4``), each of the grid search variants for ``a`` will be sampled
237
+ once. The ``b`` parameter will be sampled randomly.
238
+
239
+ The generator accepts a pre-set list of points that should be evaluated.
240
+ The points will replace the first samples of each experiment passed to
241
+ the ``BasicVariantGenerator``.
242
+
243
+ Each point will replace one sample of the specified ``num_samples``. If
244
+ grid search variables are overwritten with the values specified in the
245
+ presets, the number of samples will thus be reduced.
246
+
247
+ Example:
248
+
249
+ .. code-block:: python
250
+
251
+ from ray import tune
252
+ from ray.tune.search.basic_variant import BasicVariantGenerator
253
+
254
+ tuner = tune.Tuner(
255
+ lambda config: config["a"] + config["b"],
256
+ tune_config=tune.TuneConfig(
257
+ search_alg=BasicVariantGenerator(points_to_evaluate=[
258
+ {"a": 2, "b": 2},
259
+ {"a": 1},
260
+ {"b": 2}
261
+ ]),
262
+ num_samples=4
263
+ ),
264
+ param_space={
265
+ "a": tune.grid_search([1, 2]),
266
+ "b": tune.randint(0, 3)
267
+ },
268
+ )
269
+ tuner.fit()
270
+
271
+ The example above will produce six trials via four samples:
272
+
273
+ - The first sample will produce one trial with ``a=2`` and ``b=2``.
274
+ - The second sample will produce one trial with ``a=1`` and ``b`` sampled
275
+ randomly
276
+ - The third sample will produce two trials, one for each grid search
277
+ value of ``a``. It will be ``b=2`` for both of these trials.
278
+ - The fourth sample will produce two trials, one for each grid search
279
+ value of ``a``. ``b`` will be sampled randomly and independently for
280
+ both of these trials.
281
+
282
+ """
283
+
284
+ CKPT_FILE_TMPL = "basic-variant-state-{}.json"
285
+
286
+ def __init__(
287
+ self,
288
+ points_to_evaluate: Optional[List[Dict]] = None,
289
+ max_concurrent: int = 0,
290
+ constant_grid_search: bool = False,
291
+ random_state: Optional[
292
+ Union[int, "np_random_generator", np.random.RandomState]
293
+ ] = None,
294
+ ):
295
+ tag_searcher(self)
296
+ self._trial_generator = []
297
+ self._iterators = []
298
+ self._trial_iter = None
299
+ self._finished = False
300
+ self._random_state = _BackwardsCompatibleNumpyRng(random_state)
301
+
302
+ self._points_to_evaluate = points_to_evaluate or []
303
+
304
+ # Unique prefix for all trials generated, e.g., trial ids start as
305
+ # 2f1e_00001, 2f1ef_00002, 2f1ef_0003, etc. Overridable for testing.
306
+ force_test_uuid = os.environ.get("_TEST_TUNE_TRIAL_UUID")
307
+ if force_test_uuid:
308
+ self._uuid_prefix = force_test_uuid + "_"
309
+ else:
310
+ self._uuid_prefix = str(uuid.uuid1().hex)[:5] + "_"
311
+
312
+ self._total_samples = 0
313
+ self.max_concurrent = max_concurrent
314
+ self._constant_grid_search = constant_grid_search
315
+ self._live_trials = set()
316
+
317
+ @property
318
+ def total_samples(self):
319
+ return self._total_samples
320
+
321
+ def add_configurations(
322
+ self, experiments: Union["Experiment", List["Experiment"], Dict[str, Dict]]
323
+ ):
324
+ """Chains generator given experiment specifications.
325
+
326
+ Arguments:
327
+ experiments: Experiments to run.
328
+ """
329
+ from ray.tune.experiment import _convert_to_experiment_list
330
+
331
+ experiment_list = _convert_to_experiment_list(experiments)
332
+
333
+ for experiment in experiment_list:
334
+ grid_vals = _count_spec_samples(experiment.spec, num_samples=1)
335
+ lazy_eval = grid_vals > SERIALIZATION_THRESHOLD
336
+ if lazy_eval:
337
+ warnings.warn(
338
+ f"The number of pre-generated samples ({grid_vals}) "
339
+ "exceeds the serialization threshold "
340
+ f"({int(SERIALIZATION_THRESHOLD)}). Resume ability is "
341
+ "disabled. To fix this, reduce the number of "
342
+ "dimensions/size of the provided grid search."
343
+ )
344
+
345
+ previous_samples = self._total_samples
346
+ points_to_evaluate = copy.deepcopy(self._points_to_evaluate)
347
+ self._total_samples += _count_variants(experiment.spec, points_to_evaluate)
348
+ iterator = _TrialIterator(
349
+ uuid_prefix=self._uuid_prefix,
350
+ num_samples=experiment.spec.get("num_samples", 1),
351
+ unresolved_spec=experiment.spec,
352
+ constant_grid_search=self._constant_grid_search,
353
+ points_to_evaluate=points_to_evaluate,
354
+ lazy_eval=lazy_eval,
355
+ start=previous_samples,
356
+ random_state=self._random_state,
357
+ )
358
+ self._iterators.append(iterator)
359
+ self._trial_generator = itertools.chain(self._trial_generator, iterator)
360
+
361
+ def next_trial(self):
362
+ """Provides one Trial object to be queued into the TrialRunner.
363
+
364
+ Returns:
365
+ Trial: Returns a single trial.
366
+ """
367
+ if self.is_finished():
368
+ return None
369
+ if self.max_concurrent > 0 and len(self._live_trials) >= self.max_concurrent:
370
+ return None
371
+ if not self._trial_iter:
372
+ self._trial_iter = iter(self._trial_generator)
373
+ try:
374
+ trial = next(self._trial_iter)
375
+ self._live_trials.add(trial.trial_id)
376
+ return trial
377
+ except StopIteration:
378
+ self._trial_generator = []
379
+ self._trial_iter = None
380
+ self.set_finished()
381
+ return None
382
+
383
+ def on_trial_complete(
384
+ self, trial_id: str, result: Optional[Dict] = None, error: bool = False
385
+ ):
386
+ if trial_id in self._live_trials:
387
+ self._live_trials.remove(trial_id)
388
+
389
+ def get_state(self):
390
+ if any(iterator.lazy_eval for iterator in self._iterators):
391
+ return False
392
+ state = self.__dict__.copy()
393
+ del state["_trial_generator"]
394
+ return state
395
+
396
+ def set_state(self, state):
397
+ self.__dict__.update(state)
398
+ for iterator in self._iterators:
399
+ self._trial_generator = itertools.chain(self._trial_generator, iterator)
400
+
401
+ def save_to_dir(self, dirpath, session_str):
402
+ if any(iterator.lazy_eval for iterator in self._iterators):
403
+ return False
404
+ state_dict = self.get_state()
405
+ _atomic_save(
406
+ state=state_dict,
407
+ checkpoint_dir=dirpath,
408
+ file_name=self.CKPT_FILE_TMPL.format(session_str),
409
+ tmp_file_name=".tmp_generator",
410
+ )
411
+
412
+ def has_checkpoint(self, dirpath: str):
413
+ """Whether a checkpoint file exists within dirpath."""
414
+ return any(Path(dirpath).glob(self.CKPT_FILE_TMPL.format("*")))
415
+
416
+ def restore_from_dir(self, dirpath: str):
417
+ """Restores self + searcher + search wrappers from dirpath."""
418
+ state_dict = _load_newest_checkpoint(dirpath, self.CKPT_FILE_TMPL.format("*"))
419
+ if not state_dict:
420
+ raise RuntimeError("Unable to find checkpoint in {}.".format(dirpath))
421
+ self.set_state(state_dict)
.venv/lib/python3.11/site-packages/ray/tune/search/concurrency_limiter.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ from typing import Dict, List, Optional
4
+
5
+ from ray.tune.search.searcher import Searcher
6
+ from ray.tune.search.util import _set_search_properties_backwards_compatible
7
+ from ray.util.annotations import PublicAPI
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
@PublicAPI
class ConcurrencyLimiter(Searcher):
    """A wrapper algorithm for limiting the number of concurrent trials.

    Certain Searchers have their own internal logic for limiting
    the number of concurrent trials. If such a Searcher is passed to a
    ``ConcurrencyLimiter``, the ``max_concurrent`` of the
    ``ConcurrencyLimiter`` will override the ``max_concurrent`` value
    of the Searcher. The ``ConcurrencyLimiter`` will then let the
    Searcher's internal logic take over.

    Args:
        searcher: Searcher object that the
            ConcurrencyLimiter will manage.
        max_concurrent: Maximum concurrent samples from the underlying
            searcher.
        batch: Whether to wait for all concurrent samples
            to finish before updating the underlying searcher.

    Example:

    .. code-block:: python

        from ray.tune.search import ConcurrencyLimiter
        search_alg = HyperOptSearch(metric="accuracy")
        search_alg = ConcurrencyLimiter(search_alg, max_concurrent=2)
        tuner = tune.Tuner(
            trainable,
            tune_config=tune.TuneConfig(
                search_alg=search_alg
            ),
        )
        tuner.fit()
    """

    def __init__(self, searcher: Searcher, max_concurrent: int, batch: bool = False):
        assert type(max_concurrent) is int and max_concurrent > 0
        self.searcher = searcher
        self.max_concurrent = max_concurrent
        self.batch = batch
        # Trial IDs that have been handed out and not yet completed. Only
        # maintained when this wrapper enforces the limit itself (see
        # `_limit_concurrency`).
        self.live_trials = set()
        self.num_unfinished_live_trials = 0
        # trial_id -> (result, error), buffered in `batch` mode until every
        # live trial of the current batch has finished.
        self.cached_results = {}
        # Becomes False when the wrapped searcher implements its own
        # concurrency handling, turning this wrapper into a pass-through.
        self._limit_concurrency = True

        if not isinstance(searcher, Searcher):
            raise RuntimeError(
                f"The `ConcurrencyLimiter` only works with `Searcher` "
                f"objects (got {type(searcher)}). Please try to pass "
                f"`max_concurrent` to the search generator directly."
            )

        self._set_searcher_max_concurrency()

        super(ConcurrencyLimiter, self).__init__(
            metric=self.searcher.metric, mode=self.searcher.mode
        )

    def _set_searcher_max_concurrency(self):
        # If the searcher has special logic for handling max concurrency,
        # we do not do anything inside the ConcurrencyLimiter
        self._limit_concurrency = not self.searcher.set_max_concurrency(
            self.max_concurrent
        )

    def set_max_concurrency(self, max_concurrent: int) -> bool:
        # Determine if this behavior is acceptable, or if it should
        # raise an exception.
        self.max_concurrent = max_concurrent
        return True

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
    ) -> bool:
        """Re-sync the concurrency handling, then forward the search
        properties to the wrapped searcher."""
        self._set_searcher_max_concurrency()
        return _set_search_properties_backwards_compatible(
            self.searcher.set_search_properties, metric, mode, config, **spec
        )

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Return a new configuration from the wrapped searcher, or ``None``
        when the concurrency cap is reached."""
        if not self._limit_concurrency:
            return self.searcher.suggest(trial_id)

        assert (
            trial_id not in self.live_trials
        ), f"Trial ID {trial_id} must be unique: already found in set."
        if len(self.live_trials) >= self.max_concurrent:
            logger.debug(
                f"Not providing a suggestion for {trial_id} due to "
                "concurrency limit: %s/%s.",
                len(self.live_trials),
                self.max_concurrent,
            )
            return

        suggestion = self.searcher.suggest(trial_id)
        # None / FINISHED suggestions never started, so they do not occupy
        # a concurrency slot.
        if suggestion not in (None, Searcher.FINISHED):
            self.live_trials.add(trial_id)
            self.num_unfinished_live_trials += 1
        return suggestion

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """Record a finished trial and (eventually) update the searcher.

        In ``batch`` mode, results are cached and the wrapped searcher is
        only updated once the whole batch of live trials has finished.
        """
        if not self._limit_concurrency:
            return self.searcher.on_trial_complete(trial_id, result=result, error=error)

        if trial_id not in self.live_trials:
            return
        elif self.batch:
            self.cached_results[trial_id] = (result, error)
            self.num_unfinished_live_trials -= 1
            if self.num_unfinished_live_trials <= 0:
                # Update the underlying searcher once the
                # full batch is completed.
                for trial_id, (result, error) in self.cached_results.items():
                    self.searcher.on_trial_complete(
                        trial_id, result=result, error=error
                    )
                    self.live_trials.remove(trial_id)
                self.cached_results = {}
                self.num_unfinished_live_trials = 0
            else:
                return
        else:
            self.searcher.on_trial_complete(trial_id, result=result, error=error)
            self.live_trials.remove(trial_id)
            self.num_unfinished_live_trials -= 1

    def on_trial_result(self, trial_id: str, result: Dict) -> None:
        self.searcher.on_trial_result(trial_id, result)

    def add_evaluated_point(
        self,
        parameters: Dict,
        value: float,
        error: bool = False,
        pruned: bool = False,
        intermediate_values: Optional[List[float]] = None,
    ):
        return self.searcher.add_evaluated_point(
            parameters, value, error, pruned, intermediate_values
        )

    def get_state(self) -> Dict:
        # The wrapped searcher is excluded here; it checkpoints itself
        # through save()/restore() below.
        state = self.__dict__.copy()
        del state["searcher"]
        return copy.deepcopy(state)

    def set_state(self, state: Dict):
        self.__dict__.update(state)

    def save(self, checkpoint_path: str):
        self.searcher.save(checkpoint_path)

    def restore(self, checkpoint_path: str):
        self.searcher.restore(checkpoint_path)

    # BOHB Specific.
    # TODO(team-ml): Refactor alongside HyperBandForBOHB
    def on_pause(self, trial_id: str):
        self.searcher.on_pause(trial_id)

    def on_unpause(self, trial_id: str):
        self.searcher.on_unpause(trial_id)
.venv/lib/python3.11/site-packages/ray/tune/search/hebo/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
# Re-export the public entry point of the HEBO search integration so users
# can `from ray.tune.search.hebo import HEBOSearch`.
from ray.tune.search.hebo.hebo_search import HEBOSearch

__all__ = ["HEBOSearch"]
.venv/lib/python3.11/site-packages/ray/tune/search/hebo/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (299 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/search/hebo/__pycache__/hebo_search.cpython-311.pyc ADDED
Binary file (21 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
# Re-export the public entry point of the Nevergrad search integration so
# users can `from ray.tune.search.nevergrad import NevergradSearch`.
from ray.tune.search.nevergrad.nevergrad_search import NevergradSearch

__all__ = ["NevergradSearch"]
.venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (320 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/__pycache__/nevergrad_search.cpython-311.pyc ADDED
Binary file (17.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/nevergrad_search.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import logging
3
+ import pickle
4
+ from typing import Dict, List, Optional, Sequence, Type, Union
5
+
6
+ from ray.tune.result import DEFAULT_METRIC
7
+ from ray.tune.search import (
8
+ UNDEFINED_METRIC_MODE,
9
+ UNDEFINED_SEARCH_SPACE,
10
+ UNRESOLVED_SEARCH_SPACE,
11
+ Searcher,
12
+ )
13
+ from ray.tune.search.sample import (
14
+ Categorical,
15
+ Domain,
16
+ Float,
17
+ Integer,
18
+ LogUniform,
19
+ Quantized,
20
+ )
21
+ from ray.tune.search.variant_generator import parse_spec_vars
22
+ from ray.tune.utils.util import flatten_dict, unflatten_dict
23
+
24
+ try:
25
+ import nevergrad as ng
26
+ from nevergrad.optimization import Optimizer
27
+ from nevergrad.optimization.base import ConfiguredOptimizer
28
+
29
+ Parameter = ng.p.Parameter
30
+ except ImportError:
31
+ ng = None
32
+ Optimizer = None
33
+ ConfiguredOptimizer = None
34
+ Parameter = None
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
class NevergradSearch(Searcher):
    """Uses Nevergrad to optimize hyperparameters.

    Nevergrad is an open source tool from Facebook for derivative free
    optimization. More info can be found at:
    https://github.com/facebookresearch/nevergrad.

    You will need to install Nevergrad via the following command:

    .. code-block:: bash

        $ pip install nevergrad

    Parameters:
        optimizer: Optimizer class provided from Nevergrad.
            See here for available optimizers:
            https://facebookresearch.github.io/nevergrad/optimizers_ref.html#optimizers
            This can also be an instance of a `ConfiguredOptimizer`. See the
            section on configured optimizers in the above link.
        optimizer_kwargs: Kwargs passed in when instantiating the `optimizer`
        space: Nevergrad parametrization
            to be passed to optimizer on instantiation, or list of parameter
            names if you passed an optimizer object.
        metric: The training result objective value attribute. If None
            but a mode was passed, the anonymous metric `_metric` will be used
            per default.
        mode: One of {min, max}. Determines whether objective is
            minimizing or maximizing the metric attribute.
        points_to_evaluate: Initial parameter suggestions to be run
            first. This is for when you already have some good parameters
            you want to run first to help the algorithm make better suggestions
            for future parameters. Needs to be a list of dicts containing the
            configurations.

    Tune automatically converts search spaces to Nevergrad's format:

    .. code-block:: python

        import nevergrad as ng

        config = {
            "width": tune.uniform(0, 20),
            "height": tune.uniform(-100, 100),
            "activation": tune.choice(["relu", "tanh"])
        }

        current_best_params = [{
            "width": 10,
            "height": 0,
            "activation": "relu",
        }]

        ng_search = NevergradSearch(
            optimizer=ng.optimizers.OnePlusOne,
            metric="mean_loss",
            mode="min",
            points_to_evaluate=current_best_params)

        run(my_trainable, config=config, search_alg=ng_search)

    If you would like to pass the search space manually, the code would
    look like this:

    .. code-block:: python

        import nevergrad as ng

        space = ng.p.Dict(
            width=ng.p.Scalar(lower=0, upper=20),
            height=ng.p.Scalar(lower=-100, upper=100),
            activation=ng.p.Choice(choices=["relu", "tanh"])
        )

        ng_search = NevergradSearch(
            optimizer=ng.optimizers.OnePlusOne,
            space=space,
            metric="mean_loss",
            mode="min")

        run(my_trainable, search_alg=ng_search)

    """

    def __init__(
        self,
        optimizer: Optional[
            Union[Optimizer, Type[Optimizer], ConfiguredOptimizer]
        ] = None,
        optimizer_kwargs: Optional[Dict] = None,
        space: Optional[Union[Dict, Parameter]] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        points_to_evaluate: Optional[List[Dict]] = None,
    ):
        assert (
            ng is not None
        ), """Nevergrad must be installed!
            You can install Nevergrad with the command:
            `pip install nevergrad`."""
        if mode:
            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."

        super(NevergradSearch, self).__init__(metric=metric, mode=mode)

        self._space = None
        self._opt_factory = None
        self._nevergrad_opt = None
        self._optimizer_kwargs = optimizer_kwargs or {}

        if points_to_evaluate is None:
            self._points_to_evaluate = None
        elif not isinstance(points_to_evaluate, Sequence):
            raise ValueError(
                "Invalid object type passed for `points_to_evaluate`: "
                f"{type(points_to_evaluate)}. "
                "Please pass a list of points (dictionaries) instead."
            )
        else:
            self._points_to_evaluate = list(points_to_evaluate)

        # A Tune-style dict search space is converted to a Nevergrad
        # parametrization up front.
        if isinstance(space, dict) and space:
            resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
            if domain_vars or grid_vars:
                logger.warning(
                    UNRESOLVED_SEARCH_SPACE.format(par="space", cls=type(self))
                )
            space = self.convert_search_space(space)

        if isinstance(optimizer, Optimizer):
            # Already-instantiated optimizer: `space` may only be a list of
            # parameter names (or None), and kwargs make no sense here.
            if space is not None and not isinstance(space, list):
                raise ValueError(
                    "If you pass a configured optimizer to Nevergrad, either "
                    "pass a list of parameter names or None as the `space` "
                    "parameter."
                )
            if self._optimizer_kwargs:
                raise ValueError(
                    "If you pass in optimizer kwargs, either pass "
                    "an `Optimizer` subclass or an instance of "
                    "`ConfiguredOptimizer`."
                )

            self._parameters = space
            self._nevergrad_opt = optimizer
        elif (
            inspect.isclass(optimizer) and issubclass(optimizer, Optimizer)
        ) or isinstance(optimizer, ConfiguredOptimizer):
            # Factory: instantiation is deferred to _setup_nevergrad(), once
            # a search space is known.
            self._opt_factory = optimizer
            self._parameters = None
            self._space = space
        else:
            raise ValueError(
                "The `optimizer` argument passed to NevergradSearch must be "
                "either an `Optimizer` or a `ConfiguredOptimizer`."
            )

        # trial_id -> Nevergrad candidate, needed to `tell` results back.
        self._live_trial_mapping = {}

        if self._nevergrad_opt or self._space:
            self._setup_nevergrad()

    def _setup_nevergrad(self):
        """Instantiate and validate the underlying Nevergrad optimizer."""
        if self._opt_factory:
            self._nevergrad_opt = self._opt_factory(
                self._space, **self._optimizer_kwargs
            )

        # nevergrad.tell internally minimizes, so "max" => -1
        if self._mode == "max":
            self._metric_op = -1.0
        elif self._mode == "min":
            self._metric_op = 1.0

        if self._metric is None and self._mode:
            # If only a mode was passed, use anonymous metric
            self._metric = DEFAULT_METRIC

        # Validate that the optimizer's instrumentation matches how the
        # search space was supplied (kwargs-based vs. parameter-name list).
        if hasattr(self._nevergrad_opt, "instrumentation"):  # added in v0.2.0
            if self._nevergrad_opt.instrumentation.kwargs:
                if self._nevergrad_opt.instrumentation.args:
                    raise ValueError("Instrumented optimizers should use kwargs only")
                if self._parameters is not None:
                    raise ValueError(
                        "Instrumented optimizers should provide "
                        "None as parameter_names"
                    )
            else:
                if self._parameters is None:
                    raise ValueError(
                        "Non-instrumented optimizers should have "
                        "a list of parameter_names"
                    )
                if len(self._nevergrad_opt.instrumentation.args) != 1:
                    raise ValueError("Instrumented optimizers should use kwargs only")
            if self._parameters is not None and self._nevergrad_opt.dimension != len(
                self._parameters
            ):
                raise ValueError(
                    "len(parameters_names) must match optimizer "
                    "dimension for non-instrumented optimizers"
                )

        if self._points_to_evaluate:
            # Nevergrad is LIFO, so we add the points to evaluate in reverse
            # order.
            for i in range(len(self._points_to_evaluate) - 1, -1, -1):
                self._nevergrad_opt.suggest(self._points_to_evaluate[i])

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
    ) -> bool:
        """Adopt metric/mode/space from Tune if none were set at init.

        Returns ``False`` when the optimizer or a space already exists,
        signalling Tune that the properties could not be applied.
        """
        if self._nevergrad_opt or self._space:
            return False
        space = self.convert_search_space(config)
        self._space = space

        if metric:
            self._metric = metric
        if mode:
            self._mode = mode

        self._setup_nevergrad()
        return True

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Ask Nevergrad for a candidate and map it to a Tune config dict."""
        if not self._nevergrad_opt:
            raise RuntimeError(
                UNDEFINED_SEARCH_SPACE.format(
                    cls=self.__class__.__name__, space="space"
                )
            )
        if not self._metric or not self._mode:
            raise RuntimeError(
                UNDEFINED_METRIC_MODE.format(
                    cls=self.__class__.__name__, metric=self._metric, mode=self._mode
                )
            )

        suggested_config = self._nevergrad_opt.ask()

        self._live_trial_mapping[trial_id] = suggested_config
        # in v0.2.0+, output of ask() is a Candidate,
        # with fields args and kwargs
        if not suggested_config.kwargs:
            if self._parameters:
                # Positional-args optimizer: zip values back to the
                # user-supplied parameter names.
                return unflatten_dict(
                    dict(zip(self._parameters, suggested_config.args[0]))
                )
            return unflatten_dict(suggested_config.value)
        else:
            return unflatten_dict(suggested_config.kwargs)

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """Notification for the completion of trial.

        The result is internally negated when interacting with Nevergrad
        so that Nevergrad Optimizers can "maximize" this value,
        as it minimizes on default.
        """
        if result:
            self._process_result(trial_id, result)

        self._live_trial_mapping.pop(trial_id)

    def _process_result(self, trial_id: str, result: Dict):
        # `tell` the (sign-adjusted) objective back to the candidate that
        # produced this trial.
        ng_trial_info = self._live_trial_mapping[trial_id]
        self._nevergrad_opt.tell(ng_trial_info, self._metric_op * result[self._metric])

    def save(self, checkpoint_path: str):
        """Pickle the full searcher state to ``checkpoint_path``."""
        save_object = self.__dict__
        with open(checkpoint_path, "wb") as outputFile:
            pickle.dump(save_object, outputFile)

    def restore(self, checkpoint_path: str):
        """Restore searcher state pickled by :meth:`save`.

        NOTE(review): uses ``pickle.load`` — only restore checkpoints from
        trusted sources.
        """
        with open(checkpoint_path, "rb") as inputFile:
            save_object = pickle.load(inputFile)
        self.__dict__.update(save_object)

    @staticmethod
    def convert_search_space(spec: Dict) -> Parameter:
        """Convert a Tune search-space dict into an ``ng.p.Dict``.

        Raises:
            ValueError: for grid-search parameters or unsupported
                domain/sampler combinations.
        """
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

        if grid_vars:
            raise ValueError(
                "Grid search parameters cannot be automatically converted "
                "to a Nevergrad search space."
            )

        # Flatten and resolve again after checking for grid search.
        spec = flatten_dict(spec, prevent_delimiter=True)
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

        def resolve_value(domain: Domain) -> Parameter:
            # Map a single Tune Domain to the equivalent Nevergrad parameter.
            sampler = domain.get_sampler()
            if isinstance(sampler, Quantized):
                logger.warning(
                    "Nevergrad does not support quantization. Dropped quantization."
                )
                sampler = sampler.get_sampler()

            if isinstance(domain, Float):
                if isinstance(sampler, LogUniform):
                    return ng.p.Log(
                        lower=domain.lower, upper=domain.upper, exponent=sampler.base
                    )
                return ng.p.Scalar(lower=domain.lower, upper=domain.upper)

            elif isinstance(domain, Integer):
                if isinstance(sampler, LogUniform):
                    return ng.p.Log(
                        lower=domain.lower,
                        upper=domain.upper - 1,  # Upper bound exclusive
                        exponent=sampler.base,
                    ).set_integer_casting()
                return ng.p.Scalar(
                    lower=domain.lower,
                    upper=domain.upper - 1,  # Upper bound exclusive
                ).set_integer_casting()

            elif isinstance(domain, Categorical):
                return ng.p.Choice(choices=domain.categories)

            raise ValueError(
                "Nevergrad does not support parameters of type "
                "`{}` with samplers of type `{}`".format(
                    type(domain).__name__, type(domain.sampler).__name__
                )
            )

        # Parameter name is e.g. "a/b/c" for nested dicts
        space = {"/".join(path): resolve_value(domain) for path, domain in domain_vars}

        return ng.p.Dict(**space)
.venv/lib/python3.11/site-packages/ray/tune/search/optuna/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
# Re-export the public entry point of the Optuna search integration so users
# can `from ray.tune.search.optuna import OptunaSearch`.
from ray.tune.search.optuna.optuna_search import OptunaSearch

__all__ = ["OptunaSearch"]
.venv/lib/python3.11/site-packages/ray/tune/search/optuna/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (308 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/search/optuna/__pycache__/optuna_search.cpython-311.pyc ADDED
Binary file (31.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/search/optuna/optuna_search.py ADDED
@@ -0,0 +1,733 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import logging
3
+ import pickle
4
+ import time
5
+ import warnings
6
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
7
+
8
+ from packaging import version
9
+
10
+ from ray.air.constants import TRAINING_ITERATION
11
+ from ray.tune.result import DEFAULT_METRIC
12
+ from ray.tune.search import (
13
+ UNDEFINED_METRIC_MODE,
14
+ UNDEFINED_SEARCH_SPACE,
15
+ UNRESOLVED_SEARCH_SPACE,
16
+ Searcher,
17
+ )
18
+ from ray.tune.search.sample import (
19
+ Categorical,
20
+ Domain,
21
+ Float,
22
+ Integer,
23
+ LogUniform,
24
+ Quantized,
25
+ Uniform,
26
+ )
27
+ from ray.tune.search.variant_generator import parse_spec_vars
28
+ from ray.tune.utils.util import flatten_dict, unflatten_dict, validate_warmstart
29
+
30
+ try:
31
+ import optuna as ot
32
+ from optuna.distributions import BaseDistribution as OptunaDistribution
33
+ from optuna.samplers import BaseSampler
34
+ from optuna.storages import BaseStorage
35
+ from optuna.trial import Trial as OptunaTrial
36
+ from optuna.trial import TrialState as OptunaTrialState
37
+ except ImportError:
38
+ ot = None
39
+ OptunaDistribution = None
40
+ BaseSampler = None
41
+ BaseStorage = None
42
+ OptunaTrialState = None
43
+ OptunaTrial = None
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+ # print a warning if define by run function takes longer than this to execute
48
+ DEFINE_BY_RUN_WARN_THRESHOLD_S = 1 # 1 is arbitrary
49
+
50
+
51
+ class _OptunaTrialSuggestCaptor:
52
+ """Utility to capture returned values from Optuna's suggest_ methods.
53
+
54
+ This will wrap around the ``optuna.Trial` object and decorate all
55
+ `suggest_` callables with a function capturing the returned value,
56
+ which will be saved in the ``captured_values`` dict.
57
+ """
58
+
59
+ def __init__(self, ot_trial: OptunaTrial) -> None:
60
+ self.ot_trial = ot_trial
61
+ self.captured_values: Dict[str, Any] = {}
62
+
63
+ def _get_wrapper(self, func: Callable) -> Callable:
64
+ @functools.wraps(func)
65
+ def wrapper(*args, **kwargs):
66
+ # name is always the first arg for suggest_ methods
67
+ name = kwargs.get("name", args[0])
68
+ ret = func(*args, **kwargs)
69
+ self.captured_values[name] = ret
70
+ return ret
71
+
72
+ return wrapper
73
+
74
+ def __getattr__(self, item_name: str) -> Any:
75
+ item = getattr(self.ot_trial, item_name)
76
+ if item_name.startswith("suggest_") and callable(item):
77
+ return self._get_wrapper(item)
78
+ return item
79
+
80
+
81
+ class OptunaSearch(Searcher):
82
+ """A wrapper around Optuna to provide trial suggestions.
83
+
84
+ `Optuna <https://optuna.org/>`_ is a hyperparameter optimization library.
85
+ In contrast to other libraries, it employs define-by-run style
86
+ hyperparameter definitions.
87
+
88
+ This Searcher is a thin wrapper around Optuna's search algorithms.
89
+ You can pass any Optuna sampler, which will be used to generate
90
+ hyperparameter suggestions.
91
+
92
+ Multi-objective optimization is supported.
93
+
94
+ Args:
95
+ space: Hyperparameter search space definition for
96
+ Optuna's sampler. This can be either a :class:`dict` with
97
+ parameter names as keys and ``optuna.distributions`` as values,
98
+ or a Callable - in which case, it should be a define-by-run
99
+ function using ``optuna.trial`` to obtain the hyperparameter
100
+ values. The function should return either a :class:`dict` of
101
+ constant values with names as keys, or None.
102
+ For more information, see https://optuna.readthedocs.io\
103
+ /en/stable/tutorial/10_key_features/002_configurations.html.
104
+
105
+ .. warning::
106
+ No actual computation should take place in the define-by-run
107
+ function. Instead, put the training logic inside the function
108
+ or class trainable passed to ``tune.Tuner()``.
109
+
110
+ metric: The training result objective value attribute. If
111
+ None but a mode was passed, the anonymous metric ``_metric``
112
+ will be used per default. Can be a list of metrics for
113
+ multi-objective optimization.
114
+ mode: One of {min, max}. Determines whether objective is
115
+ minimizing or maximizing the metric attribute. Can be a list of
116
+ modes for multi-objective optimization (corresponding to
117
+ ``metric``).
118
+ points_to_evaluate: Initial parameter suggestions to be run
119
+ first. This is for when you already have some good parameters
120
+ you want to run first to help the algorithm make better suggestions
121
+ for future parameters. Needs to be a list of dicts containing the
122
+ configurations.
123
+ sampler: Optuna sampler used to
124
+ draw hyperparameter configurations. Defaults to ``MOTPESampler``
125
+ for multi-objective optimization with Optuna<2.9.0, and
126
+ ``TPESampler`` in every other case.
127
+ See https://optuna.readthedocs.io/en/stable/reference/samplers/index.html
128
+ for available Optuna samplers.
129
+
130
+ .. warning::
131
+ Please note that with Optuna 2.10.0 and earlier
132
+ default ``MOTPESampler``/``TPESampler`` suffer
133
+ from performance issues when dealing with a large number of
134
+ completed trials (approx. >100). This will manifest as
135
+ a delay when suggesting new configurations.
136
+ This is an Optuna issue and may be fixed in a future
137
+ Optuna release.
138
+ study_name: Optuna study name that uniquely identifies the trial
139
+ results. Defaults to ``"optuna"``.
140
+ storage: Optuna storage used for storing trial results to
141
+ storages other than in-memory storage,
142
+ for instance optuna.storages.RDBStorage.
143
+ seed: Seed to initialize sampler with. This parameter is only
144
+ used when ``sampler=None``. In all other cases, the sampler
145
+ you pass should be initialized with the seed already.
146
+ evaluated_rewards: If you have previously evaluated the
147
+ parameters passed in as points_to_evaluate you can avoid
148
+ re-running those trials by passing in the reward attributes
149
+ as a list so the optimiser can be told the results without
150
+ needing to re-compute the trial. Must be the same length as
151
+ points_to_evaluate.
152
+
153
+ .. warning::
154
+ When using ``evaluated_rewards``, the search space ``space``
155
+ must be provided as a :class:`dict` with parameter names as
156
+ keys and ``optuna.distributions`` instances as values. The
157
+ define-by-run search space definition is not yet supported with
158
+ this functionality.
159
+
160
+ Tune automatically converts search spaces to Optuna's format:
161
+
162
+ .. code-block:: python
163
+
164
+ from ray.tune.search.optuna import OptunaSearch
165
+
166
+ config = {
167
+ "a": tune.uniform(6, 8)
168
+ "b": tune.loguniform(1e-4, 1e-2)
169
+ }
170
+
171
+ optuna_search = OptunaSearch(
172
+ metric="loss",
173
+ mode="min")
174
+
175
+ tuner = tune.Tuner(
176
+ trainable,
177
+ tune_config=tune.TuneConfig(
178
+ search_alg=optuna_search,
179
+ ),
180
+ param_space=config,
181
+ )
182
+ tuner.fit()
183
+
184
+ If you would like to pass the search space manually, the code would
185
+ look like this:
186
+
187
+ .. code-block:: python
188
+
189
+ from ray.tune.search.optuna import OptunaSearch
190
+ import optuna
191
+
192
+ space = {
193
+ "a": optuna.distributions.FloatDistribution(6, 8),
194
+ "b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
195
+ }
196
+
197
+ optuna_search = OptunaSearch(
198
+ space,
199
+ metric="loss",
200
+ mode="min")
201
+
202
+ tuner = tune.Tuner(
203
+ trainable,
204
+ tune_config=tune.TuneConfig(
205
+ search_alg=optuna_search,
206
+ ),
207
+ )
208
+ tuner.fit()
209
+
210
+ # Equivalent Optuna define-by-run function approach:
211
+
212
+ def define_search_space(trial: optuna.Trial):
213
+ trial.suggest_float("a", 6, 8)
214
+ trial.suggest_float("b", 1e-4, 1e-2, log=True)
215
+ # training logic goes into trainable, this is just
216
+ # for search space definition
217
+
218
+ optuna_search = OptunaSearch(
219
+ define_search_space,
220
+ metric="loss",
221
+ mode="min")
222
+
223
+ tuner = tune.Tuner(
224
+ trainable,
225
+ tune_config=tune.TuneConfig(
226
+ search_alg=optuna_search,
227
+ ),
228
+ )
229
+ tuner.fit()
230
+
231
+ Multi-objective optimization is supported:
232
+
233
+ .. code-block:: python
234
+
235
+ from ray.tune.search.optuna import OptunaSearch
236
+ import optuna
237
+
238
+ space = {
239
+ "a": optuna.distributions.FloatDistribution(6, 8),
240
+ "b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
241
+ }
242
+
243
+ # Note you have to specify metric and mode here instead of
244
+ # in tune.TuneConfig
245
+ optuna_search = OptunaSearch(
246
+ space,
247
+ metric=["loss1", "loss2"],
248
+ mode=["min", "max"])
249
+
250
+ # Do not specify metric and mode here!
251
+ tuner = tune.Tuner(
252
+ trainable,
253
+ tune_config=tune.TuneConfig(
254
+ search_alg=optuna_search,
255
+ ),
256
+ )
257
+ tuner.fit()
258
+
259
+ You can pass configs that will be evaluated first using
260
+ ``points_to_evaluate``:
261
+
262
+ .. code-block:: python
263
+
264
+ from ray.tune.search.optuna import OptunaSearch
265
+ import optuna
266
+
267
+ space = {
268
+ "a": optuna.distributions.FloatDistribution(6, 8),
269
+ "b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
270
+ }
271
+
272
+ optuna_search = OptunaSearch(
273
+ space,
274
+ points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}]
275
+ metric="loss",
276
+ mode="min")
277
+
278
+ tuner = tune.Tuner(
279
+ trainable,
280
+ tune_config=tune.TuneConfig(
281
+ search_alg=optuna_search,
282
+ ),
283
+ )
284
+ tuner.fit()
285
+
286
+ Avoid re-running evaluated trials by passing the rewards together with
287
+ `points_to_evaluate`:
288
+
289
+ .. code-block:: python
290
+
291
+ from ray.tune.search.optuna import OptunaSearch
292
+ import optuna
293
+
294
+ space = {
295
+ "a": optuna.distributions.FloatDistribution(6, 8),
296
+ "b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
297
+ }
298
+
299
+ optuna_search = OptunaSearch(
300
+ space,
301
+ points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}]
302
+ evaluated_rewards=[0.89, 0.42]
303
+ metric="loss",
304
+ mode="min")
305
+
306
+ tuner = tune.Tuner(
307
+ trainable,
308
+ tune_config=tune.TuneConfig(
309
+ search_alg=optuna_search,
310
+ ),
311
+ )
312
+ tuner.fit()
313
+
314
+ .. versionadded:: 0.8.8
315
+
316
+ """
317
+
318
+ def __init__(
319
+ self,
320
+ space: Optional[
321
+ Union[
322
+ Dict[str, "OptunaDistribution"],
323
+ List[Tuple],
324
+ Callable[["OptunaTrial"], Optional[Dict[str, Any]]],
325
+ ]
326
+ ] = None,
327
+ metric: Optional[Union[str, List[str]]] = None,
328
+ mode: Optional[Union[str, List[str]]] = None,
329
+ points_to_evaluate: Optional[List[Dict]] = None,
330
+ sampler: Optional["BaseSampler"] = None,
331
+ study_name: Optional[str] = None,
332
+ storage: Optional["BaseStorage"] = None,
333
+ seed: Optional[int] = None,
334
+ evaluated_rewards: Optional[List] = None,
335
+ ):
336
+ assert ot is not None, "Optuna must be installed! Run `pip install optuna`."
337
+ super(OptunaSearch, self).__init__(metric=metric, mode=mode)
338
+
339
+ if isinstance(space, dict) and space:
340
+ resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
341
+ if domain_vars or grid_vars:
342
+ logger.warning(
343
+ UNRESOLVED_SEARCH_SPACE.format(par="space", cls=type(self).__name__)
344
+ )
345
+ space = self.convert_search_space(space)
346
+ else:
347
+ # Flatten to support nested dicts
348
+ space = flatten_dict(space, "/")
349
+
350
+ self._space = space
351
+
352
+ self._points_to_evaluate = points_to_evaluate or []
353
+ self._evaluated_rewards = evaluated_rewards
354
+ if study_name:
355
+ self._study_name = study_name
356
+ else:
357
+ self._study_name = "optuna" # Fixed study name for in-memory storage
358
+
359
+ if sampler and seed:
360
+ logger.warning(
361
+ "You passed an initialized sampler to `OptunaSearch`. The "
362
+ "`seed` parameter has to be passed to the sampler directly "
363
+ "and will be ignored."
364
+ )
365
+ elif sampler:
366
+ assert isinstance(sampler, BaseSampler), (
367
+ "You can only pass an instance of "
368
+ "`optuna.samplers.BaseSampler` "
369
+ "as a sampler to `OptunaSearcher`."
370
+ )
371
+
372
+ self._sampler = sampler
373
+ self._seed = seed
374
+
375
+ if storage:
376
+ assert isinstance(storage, BaseStorage), (
377
+ "The `storage` parameter in `OptunaSearcher` must be an instance "
378
+ "of `optuna.storages.BaseStorage`."
379
+ )
380
+ # If storage is not provided, just set self._storage to None
381
+ # so that the default in-memory storage is used.
382
+ self._storage = storage
383
+
384
+ self._completed_trials = set()
385
+
386
+ self._ot_trials = {}
387
+ self._ot_study = None
388
+ if self._space:
389
+ self._setup_study(mode)
390
+
391
+ def _setup_study(self, mode: Union[str, list]):
392
+ if self._metric is None and self._mode:
393
+ if isinstance(self._mode, list):
394
+ raise ValueError(
395
+ "If ``mode`` is a list (multi-objective optimization "
396
+ "case), ``metric`` must be defined."
397
+ )
398
+ # If only a mode was passed, use anonymous metric
399
+ self._metric = DEFAULT_METRIC
400
+
401
+ pruner = ot.pruners.NopPruner()
402
+
403
+ if self._sampler:
404
+ sampler = self._sampler
405
+ elif isinstance(mode, list) and version.parse(ot.__version__) < version.parse(
406
+ "2.9.0"
407
+ ):
408
+ # MOTPESampler deprecated in Optuna>=2.9.0
409
+ sampler = ot.samplers.MOTPESampler(seed=self._seed)
410
+ else:
411
+ sampler = ot.samplers.TPESampler(seed=self._seed)
412
+
413
+ if isinstance(mode, list):
414
+ study_direction_args = dict(
415
+ directions=["minimize" if m == "min" else "maximize" for m in mode],
416
+ )
417
+ else:
418
+ study_direction_args = dict(
419
+ direction="minimize" if mode == "min" else "maximize",
420
+ )
421
+
422
+ self._ot_study = ot.study.create_study(
423
+ storage=self._storage,
424
+ sampler=sampler,
425
+ pruner=pruner,
426
+ study_name=self._study_name,
427
+ load_if_exists=True,
428
+ **study_direction_args,
429
+ )
430
+
431
+ if self._points_to_evaluate:
432
+ validate_warmstart(
433
+ self._space,
434
+ self._points_to_evaluate,
435
+ self._evaluated_rewards,
436
+ validate_point_name_lengths=not callable(self._space),
437
+ )
438
+ if self._evaluated_rewards:
439
+ for point, reward in zip(
440
+ self._points_to_evaluate, self._evaluated_rewards
441
+ ):
442
+ self.add_evaluated_point(point, reward)
443
+ else:
444
+ for point in self._points_to_evaluate:
445
+ self._ot_study.enqueue_trial(point)
446
+
447
+ def set_search_properties(
448
+ self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
449
+ ) -> bool:
450
+ if self._space:
451
+ return False
452
+ space = self.convert_search_space(config)
453
+ self._space = space
454
+ if metric:
455
+ self._metric = metric
456
+ if mode:
457
+ self._mode = mode
458
+
459
+ self._setup_study(self._mode)
460
+ return True
461
+
462
+ def _suggest_from_define_by_run_func(
463
+ self,
464
+ func: Callable[["OptunaTrial"], Optional[Dict[str, Any]]],
465
+ ot_trial: "OptunaTrial",
466
+ ) -> Dict:
467
+ captor = _OptunaTrialSuggestCaptor(ot_trial)
468
+ time_start = time.time()
469
+ ret = func(captor)
470
+ time_taken = time.time() - time_start
471
+ if time_taken > DEFINE_BY_RUN_WARN_THRESHOLD_S:
472
+ warnings.warn(
473
+ "Define-by-run function passed in the `space` argument "
474
+ f"took {time_taken} seconds to "
475
+ "run. Ensure that actual computation, training takes "
476
+ "place inside Tune's train functions or Trainables "
477
+ "passed to `tune.Tuner()`."
478
+ )
479
+ if ret is not None:
480
+ if not isinstance(ret, dict):
481
+ raise TypeError(
482
+ "The return value of the define-by-run function "
483
+ "passed in the `space` argument should be "
484
+ "either None or a `dict` with `str` keys. "
485
+ f"Got {type(ret)}."
486
+ )
487
+ if not all(isinstance(k, str) for k in ret.keys()):
488
+ raise TypeError(
489
+ "At least one of the keys in the dict returned by the "
490
+ "define-by-run function passed in the `space` argument "
491
+ "was not a `str`."
492
+ )
493
+ return {**captor.captured_values, **ret} if ret else captor.captured_values
494
+
495
+ def suggest(self, trial_id: str) -> Optional[Dict]:
496
+ if not self._space:
497
+ raise RuntimeError(
498
+ UNDEFINED_SEARCH_SPACE.format(
499
+ cls=self.__class__.__name__, space="space"
500
+ )
501
+ )
502
+ if not self._metric or not self._mode:
503
+ raise RuntimeError(
504
+ UNDEFINED_METRIC_MODE.format(
505
+ cls=self.__class__.__name__, metric=self._metric, mode=self._mode
506
+ )
507
+ )
508
+ if callable(self._space):
509
+ # Define-by-run case
510
+ if trial_id not in self._ot_trials:
511
+ self._ot_trials[trial_id] = self._ot_study.ask()
512
+
513
+ ot_trial = self._ot_trials[trial_id]
514
+
515
+ params = self._suggest_from_define_by_run_func(self._space, ot_trial)
516
+ else:
517
+ # Use Optuna ask interface (since version 2.6.0)
518
+ if trial_id not in self._ot_trials:
519
+ self._ot_trials[trial_id] = self._ot_study.ask(
520
+ fixed_distributions=self._space
521
+ )
522
+ ot_trial = self._ot_trials[trial_id]
523
+ params = ot_trial.params
524
+
525
+ return unflatten_dict(params)
526
+
527
+ def on_trial_result(self, trial_id: str, result: Dict):
528
+ if isinstance(self.metric, list):
529
+ # Optuna doesn't support incremental results
530
+ # for multi-objective optimization
531
+ return
532
+ if trial_id in self._completed_trials:
533
+ logger.warning(
534
+ f"Received additional result for trial {trial_id}, but "
535
+ f"it already finished. Result: {result}"
536
+ )
537
+ return
538
+ metric = result[self.metric]
539
+ step = result[TRAINING_ITERATION]
540
+ ot_trial = self._ot_trials[trial_id]
541
+ ot_trial.report(metric, step)
542
+
543
+ def on_trial_complete(
544
+ self, trial_id: str, result: Optional[Dict] = None, error: bool = False
545
+ ):
546
+ if trial_id in self._completed_trials:
547
+ logger.warning(
548
+ f"Received additional completion for trial {trial_id}, but "
549
+ f"it already finished. Result: {result}"
550
+ )
551
+ return
552
+
553
+ ot_trial = self._ot_trials[trial_id]
554
+
555
+ if result:
556
+ if isinstance(self.metric, list):
557
+ val = [result.get(metric, None) for metric in self.metric]
558
+ else:
559
+ val = result.get(self.metric, None)
560
+ else:
561
+ val = None
562
+ ot_trial_state = OptunaTrialState.COMPLETE
563
+ if val is None:
564
+ if error:
565
+ ot_trial_state = OptunaTrialState.FAIL
566
+ else:
567
+ ot_trial_state = OptunaTrialState.PRUNED
568
+ try:
569
+ self._ot_study.tell(ot_trial, val, state=ot_trial_state)
570
+ except Exception as exc:
571
+ logger.warning(exc) # E.g. if NaN was reported
572
+
573
+ self._completed_trials.add(trial_id)
574
+
575
+ def add_evaluated_point(
576
+ self,
577
+ parameters: Dict,
578
+ value: float,
579
+ error: bool = False,
580
+ pruned: bool = False,
581
+ intermediate_values: Optional[List[float]] = None,
582
+ ):
583
+ if not self._space:
584
+ raise RuntimeError(
585
+ UNDEFINED_SEARCH_SPACE.format(
586
+ cls=self.__class__.__name__, space="space"
587
+ )
588
+ )
589
+ if not self._metric or not self._mode:
590
+ raise RuntimeError(
591
+ UNDEFINED_METRIC_MODE.format(
592
+ cls=self.__class__.__name__, metric=self._metric, mode=self._mode
593
+ )
594
+ )
595
+ if callable(self._space):
596
+ raise TypeError(
597
+ "Define-by-run function passed in `space` argument is not "
598
+ "yet supported when using `evaluated_rewards`. Please provide "
599
+ "an `OptunaDistribution` dict or pass a Ray Tune "
600
+ "search space to `tune.Tuner()`."
601
+ )
602
+
603
+ ot_trial_state = OptunaTrialState.COMPLETE
604
+ if error:
605
+ ot_trial_state = OptunaTrialState.FAIL
606
+ elif pruned:
607
+ ot_trial_state = OptunaTrialState.PRUNED
608
+
609
+ if intermediate_values:
610
+ intermediate_values_dict = {
611
+ i: value for i, value in enumerate(intermediate_values)
612
+ }
613
+ else:
614
+ intermediate_values_dict = None
615
+
616
+ # If the trial state is FAILED, the value must be `None` in Optuna==4.1.0
617
+ # Reference: https://github.com/optuna/optuna/pull/5211
618
+ # This is a temporary fix for the issue that Optuna enforces the value
619
+ # to be `None` if the trial state is FAILED.
620
+ # TODO (hpguo): A better solution may requires us to update the base class
621
+ # to allow the `value` arg in `add_evaluated_point` being `Optional[float]`.
622
+ if ot_trial_state == OptunaTrialState.FAIL:
623
+ value = None
624
+
625
+ trial = ot.trial.create_trial(
626
+ state=ot_trial_state,
627
+ value=value,
628
+ params=parameters,
629
+ distributions=self._space,
630
+ intermediate_values=intermediate_values_dict,
631
+ )
632
+
633
+ self._ot_study.add_trial(trial)
634
+
635
+ def save(self, checkpoint_path: str):
636
+ save_object = self.__dict__.copy()
637
+ with open(checkpoint_path, "wb") as outputFile:
638
+ pickle.dump(save_object, outputFile)
639
+
640
+ def restore(self, checkpoint_path: str):
641
+ with open(checkpoint_path, "rb") as inputFile:
642
+ save_object = pickle.load(inputFile)
643
+ if isinstance(save_object, dict):
644
+ self.__dict__.update(save_object)
645
+ else:
646
+ # Backwards compatibility
647
+ (
648
+ self._sampler,
649
+ self._ot_trials,
650
+ self._ot_study,
651
+ self._points_to_evaluate,
652
+ self._evaluated_rewards,
653
+ ) = save_object
654
+
655
+ @staticmethod
656
+ def convert_search_space(spec: Dict) -> Dict[str, Any]:
657
+ resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
658
+
659
+ if not domain_vars and not grid_vars:
660
+ return {}
661
+
662
+ if grid_vars:
663
+ raise ValueError(
664
+ "Grid search parameters cannot be automatically converted "
665
+ "to an Optuna search space."
666
+ )
667
+
668
+ # Flatten and resolve again after checking for grid search.
669
+ spec = flatten_dict(spec, prevent_delimiter=True)
670
+ resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
671
+
672
+ def resolve_value(domain: Domain) -> ot.distributions.BaseDistribution:
673
+ quantize = None
674
+
675
+ sampler = domain.get_sampler()
676
+ if isinstance(sampler, Quantized):
677
+ quantize = sampler.q
678
+ sampler = sampler.sampler
679
+ if isinstance(sampler, LogUniform):
680
+ logger.warning(
681
+ "Optuna does not handle quantization in loguniform "
682
+ "sampling. The parameter will be passed but it will "
683
+ "probably be ignored."
684
+ )
685
+
686
+ if isinstance(domain, Float):
687
+ if isinstance(sampler, LogUniform):
688
+ if quantize:
689
+ logger.warning(
690
+ "Optuna does not support both quantization and "
691
+ "sampling from LogUniform. Dropped quantization."
692
+ )
693
+ return ot.distributions.FloatDistribution(
694
+ domain.lower, domain.upper, log=True
695
+ )
696
+
697
+ elif isinstance(sampler, Uniform):
698
+ if quantize:
699
+ return ot.distributions.FloatDistribution(
700
+ domain.lower, domain.upper, step=quantize
701
+ )
702
+ return ot.distributions.FloatDistribution(
703
+ domain.lower, domain.upper
704
+ )
705
+
706
+ elif isinstance(domain, Integer):
707
+ if isinstance(sampler, LogUniform):
708
+ return ot.distributions.IntDistribution(
709
+ domain.lower, domain.upper - 1, step=quantize or 1, log=True
710
+ )
711
+ elif isinstance(sampler, Uniform):
712
+ # Upper bound should be inclusive for quantization and
713
+ # exclusive otherwise
714
+ return ot.distributions.IntDistribution(
715
+ domain.lower,
716
+ domain.upper - int(bool(not quantize)),
717
+ step=quantize or 1,
718
+ )
719
+ elif isinstance(domain, Categorical):
720
+ if isinstance(sampler, Uniform):
721
+ return ot.distributions.CategoricalDistribution(domain.categories)
722
+
723
+ raise ValueError(
724
+ "Optuna search does not support parameters of type "
725
+ "`{}` with samplers of type `{}`".format(
726
+ type(domain).__name__, type(domain.sampler).__name__
727
+ )
728
+ )
729
+
730
+ # Parameter name is e.g. "a/b/c" for nested dicts
731
+ values = {"/".join(path): resolve_value(domain) for path, domain in domain_vars}
732
+
733
+ return values
.venv/lib/python3.11/site-packages/ray/tune/search/repeater.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ from typing import Dict, List, Optional
4
+
5
+ import numpy as np
6
+
7
+ from ray.tune.search.searcher import Searcher
8
+ from ray.tune.search.util import _set_search_properties_backwards_compatible
9
+ from ray.util import PublicAPI
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ TRIAL_INDEX = "__trial_index__"
14
+ """str: A constant value representing the repeat index of the trial."""
15
+
16
+
17
+ def _warn_num_samples(searcher: Searcher, num_samples: int):
18
+ if isinstance(searcher, Repeater) and num_samples % searcher.repeat:
19
+ logger.warning(
20
+ "`num_samples` is now expected to be the total number of trials, "
21
+ "including the repeat trials. For example, set num_samples=15 if "
22
+ "you intend to obtain 3 search algorithm suggestions and repeat "
23
+ "each suggestion 5 times. Any leftover trials "
24
+ "(num_samples mod repeat) will be ignored."
25
+ )
26
+
27
+
28
+ class _TrialGroup:
29
+ """Internal class for grouping trials of same parameters.
30
+
31
+ This is used when repeating trials for reducing training variance.
32
+
33
+ Args:
34
+ primary_trial_id: Trial ID of the "primary trial".
35
+ This trial is the one that the Searcher is aware of.
36
+ config: Suggested configuration shared across all trials
37
+ in the trial group.
38
+ max_trials: Max number of trials to execute within this group.
39
+
40
+ """
41
+
42
+ def __init__(self, primary_trial_id: str, config: Dict, max_trials: int = 1):
43
+ assert type(config) is dict, "config is not a dict, got {}".format(config)
44
+ self.primary_trial_id = primary_trial_id
45
+ self.config = config
46
+ self._trials = {primary_trial_id: None}
47
+ self.max_trials = max_trials
48
+
49
+ def add(self, trial_id: str):
50
+ assert len(self._trials) < self.max_trials
51
+ self._trials.setdefault(trial_id, None)
52
+
53
+ def full(self) -> bool:
54
+ return len(self._trials) == self.max_trials
55
+
56
+ def report(self, trial_id: str, score: float):
57
+ assert trial_id in self._trials
58
+ if score is None:
59
+ raise ValueError("Internal Error: Score cannot be None.")
60
+ self._trials[trial_id] = score
61
+
62
+ def finished_reporting(self) -> bool:
63
+ return (
64
+ None not in self._trials.values() and len(self._trials) == self.max_trials
65
+ )
66
+
67
+ def scores(self) -> List[Optional[float]]:
68
+ return list(self._trials.values())
69
+
70
+ def count(self) -> int:
71
+ return len(self._trials)
72
+
73
+
74
+ @PublicAPI
75
+ class Repeater(Searcher):
76
+ """A wrapper algorithm for repeating trials of same parameters.
77
+
78
+ Set tune.TuneConfig(num_samples=...) to be a multiple of `repeat`. For example,
79
+ set num_samples=15 if you intend to obtain 3 search algorithm suggestions
80
+ and repeat each suggestion 5 times. Any leftover trials
81
+ (num_samples mod repeat) will be ignored.
82
+
83
+ It is recommended that you do not run an early-stopping TrialScheduler
84
+ simultaneously.
85
+
86
+ Args:
87
+ searcher: Searcher object that the
88
+ Repeater will optimize. Note that the Searcher
89
+ will only see 1 trial among multiple repeated trials.
90
+ The result/metric passed to the Searcher upon
91
+ trial completion will be averaged among all repeats.
92
+ repeat: Number of times to generate a trial with a repeated
93
+ configuration. Defaults to 1.
94
+ set_index: Sets a tune.search.repeater.TRIAL_INDEX in
95
+ Trainable/Function config which corresponds to the index of the
96
+ repeated trial. This can be used for seeds. Defaults to True.
97
+
98
+ Example:
99
+
100
+ .. code-block:: python
101
+
102
+ from ray.tune.search import Repeater
103
+
104
+ search_alg = BayesOptSearch(...)
105
+ re_search_alg = Repeater(search_alg, repeat=10)
106
+
107
+ # Repeat 2 samples 10 times each.
108
+ tuner = tune.Tuner(
109
+ trainable,
110
+ tune_config=tune.TuneConfig(
111
+ search_alg=re_search_alg,
112
+ num_samples=20,
113
+ ),
114
+ )
115
+ tuner.fit()
116
+
117
+ """
118
+
119
+ def __init__(self, searcher: Searcher, repeat: int = 1, set_index: bool = True):
120
+ self.searcher = searcher
121
+ self.repeat = repeat
122
+ self._set_index = set_index
123
+ self._groups = []
124
+ self._trial_id_to_group = {}
125
+ self._current_group = None
126
+ super(Repeater, self).__init__(
127
+ metric=self.searcher.metric, mode=self.searcher.mode
128
+ )
129
+
130
+ def suggest(self, trial_id: str) -> Optional[Dict]:
131
+ if self._current_group is None or self._current_group.full():
132
+ config = self.searcher.suggest(trial_id)
133
+ if config is None:
134
+ return config
135
+ self._current_group = _TrialGroup(
136
+ trial_id, copy.deepcopy(config), max_trials=self.repeat
137
+ )
138
+ self._groups.append(self._current_group)
139
+ index_in_group = 0
140
+ else:
141
+ index_in_group = self._current_group.count()
142
+ self._current_group.add(trial_id)
143
+
144
+ config = self._current_group.config.copy()
145
+ if self._set_index:
146
+ config[TRIAL_INDEX] = index_in_group
147
+ self._trial_id_to_group[trial_id] = self._current_group
148
+ return config
149
+
150
+ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, **kwargs):
151
+ """Stores the score for and keeps track of a completed trial.
152
+
153
+ Stores the metric of a trial as nan if any of the following conditions
154
+ are met:
155
+
156
+ 1. ``result`` is empty or not provided.
157
+ 2. ``result`` is provided but no metric was provided.
158
+
159
+ """
160
+ if trial_id not in self._trial_id_to_group:
161
+ logger.error(
162
+ "Trial {} not in group; cannot report score. "
163
+ "Seen trials: {}".format(trial_id, list(self._trial_id_to_group))
164
+ )
165
+ trial_group = self._trial_id_to_group[trial_id]
166
+ if not result or self.searcher.metric not in result:
167
+ score = np.nan
168
+ else:
169
+ score = result[self.searcher.metric]
170
+ trial_group.report(trial_id, score)
171
+
172
+ if trial_group.finished_reporting():
173
+ scores = trial_group.scores()
174
+ self.searcher.on_trial_complete(
175
+ trial_group.primary_trial_id,
176
+ result={self.searcher.metric: np.nanmean(scores)},
177
+ **kwargs
178
+ )
179
+
180
+ def get_state(self) -> Dict:
181
+ self_state = self.__dict__.copy()
182
+ del self_state["searcher"]
183
+ return self_state
184
+
185
+ def set_state(self, state: Dict):
186
+ self.__dict__.update(state)
187
+
188
+ def save(self, checkpoint_path: str):
189
+ self.searcher.save(checkpoint_path)
190
+
191
+ def restore(self, checkpoint_path: str):
192
+ self.searcher.restore(checkpoint_path)
193
+
194
+ def set_search_properties(
195
+ self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
196
+ ) -> bool:
197
+ return _set_search_properties_backwards_compatible(
198
+ self.searcher.set_search_properties, metric, mode, config, **spec
199
+ )
.venv/lib/python3.11/site-packages/ray/tune/search/sample.py ADDED
@@ -0,0 +1,742 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from copy import copy
3
+ from inspect import signature
4
+ from math import isclose
5
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
6
+
7
+ import numpy as np
8
+
9
+ # Backwards compatibility
10
+ from ray.util.annotations import DeveloperAPI, PublicAPI
11
+
12
+ try:
13
+ # Added in numpy>=1.17 but we require numpy>=1.16
14
+ np_random_generator = np.random.Generator
15
+ LEGACY_RNG = False
16
+ except AttributeError:
17
+
18
+ class np_random_generator:
19
+ pass
20
+
21
+ LEGACY_RNG = True
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class _BackwardsCompatibleNumpyRng:
27
+ """Thin wrapper to ensure backwards compatibility between
28
+ new and old numpy randomness generators.
29
+ """
30
+
31
+ _rng = None
32
+
33
+ def __init__(
34
+ self,
35
+ generator_or_seed: Optional[
36
+ Union["np_random_generator", np.random.RandomState, int]
37
+ ] = None,
38
+ ):
39
+ if generator_or_seed is None or isinstance(
40
+ generator_or_seed, (np.random.RandomState, np_random_generator)
41
+ ):
42
+ self._rng = generator_or_seed
43
+ elif LEGACY_RNG:
44
+ self._rng = np.random.RandomState(generator_or_seed)
45
+ else:
46
+ self._rng = np.random.default_rng(generator_or_seed)
47
+
48
+ @property
49
+ def legacy_rng(self) -> bool:
50
+ return not isinstance(self._rng, np_random_generator)
51
+
52
+ @property
53
+ def rng(self):
54
+ # don't set self._rng to np.random to avoid picking issues
55
+ return self._rng if self._rng is not None else np.random
56
+
57
+ def __getattr__(self, name: str) -> Any:
58
+ # https://numpy.org/doc/stable/reference/random/new-or-different.html
59
+ if self.legacy_rng:
60
+ if name == "integers":
61
+ name = "randint"
62
+ elif name == "random":
63
+ name = "rand"
64
+ return getattr(self.rng, name)
65
+
66
+
67
+ RandomState = Union[
68
+ None, _BackwardsCompatibleNumpyRng, np_random_generator, np.random.RandomState, int
69
+ ]
70
+
71
+
72
+ @DeveloperAPI
73
+ class Domain:
74
+ """Base class to specify a type and valid range to sample parameters from.
75
+
76
+ This base class is implemented by parameter spaces, like float ranges
77
+ (``Float``), integer ranges (``Integer``), or categorical variables
78
+ (``Categorical``). The ``Domain`` object contains information about
79
+ valid values (e.g. minimum and maximum values), and exposes methods that
80
+ allow specification of specific samplers (e.g. ``uniform()`` or
81
+ ``loguniform()``).
82
+
83
+ """
84
+
85
+ sampler = None
86
+ default_sampler_cls = None
87
+
88
+ def cast(self, value):
89
+ """Cast value to domain type"""
90
+ return value
91
+
92
+ def set_sampler(self, sampler, allow_override=False):
93
+ if self.sampler and not allow_override:
94
+ raise ValueError(
95
+ "You can only choose one sampler for parameter "
96
+ "domains. Existing sampler for parameter {}: "
97
+ "{}. Tried to add {}".format(
98
+ self.__class__.__name__, self.sampler, sampler
99
+ )
100
+ )
101
+ self.sampler = sampler
102
+
103
+ def get_sampler(self):
104
+ sampler = self.sampler
105
+ if not sampler:
106
+ sampler = self.default_sampler_cls()
107
+ return sampler
108
+
109
+ def sample(
110
+ self,
111
+ config: Optional[Union[List[Dict], Dict]] = None,
112
+ size: int = 1,
113
+ random_state: "RandomState" = None,
114
+ ):
115
+ if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
116
+ random_state = _BackwardsCompatibleNumpyRng(random_state)
117
+ sampler = self.get_sampler()
118
+ return sampler.sample(self, config=config, size=size, random_state=random_state)
119
+
120
+ def is_grid(self):
121
+ return isinstance(self.sampler, Grid)
122
+
123
+ def is_function(self):
124
+ return False
125
+
126
+ def is_valid(self, value: Any):
127
+ """Returns True if `value` is a valid value in this domain."""
128
+ raise NotImplementedError
129
+
130
+ @property
131
+ def domain_str(self):
132
+ return "(unknown)"
133
+
134
+
135
+ @DeveloperAPI
136
+ class Sampler:
137
+ def sample(
138
+ self,
139
+ domain: Domain,
140
+ config: Optional[Union[List[Dict], Dict]] = None,
141
+ size: int = 1,
142
+ random_state: "RandomState" = None,
143
+ ):
144
+ raise NotImplementedError
145
+
146
+
147
+ @DeveloperAPI
148
+ class BaseSampler(Sampler):
149
+ def __str__(self):
150
+ return "Base"
151
+
152
+
153
+ @DeveloperAPI
154
+ class Uniform(Sampler):
155
+ def __str__(self):
156
+ return "Uniform"
157
+
158
+
159
+ @DeveloperAPI
160
+ class LogUniform(Sampler):
161
+ def __init__(self, base: float = 10):
162
+ self.base = base
163
+ assert self.base > 0, "Base has to be strictly greater than 0"
164
+
165
+ def __str__(self):
166
+ return "LogUniform"
167
+
168
+
169
+ @DeveloperAPI
170
+ class Normal(Sampler):
171
+ def __init__(self, mean: float = 0.0, sd: float = 0.0):
172
+ self.mean = mean
173
+ self.sd = sd
174
+
175
+ assert self.sd > 0, "SD has to be strictly greater than 0"
176
+
177
+ def __str__(self):
178
+ return "Normal"
179
+
180
+
181
+ @DeveloperAPI
182
+ class Grid(Sampler):
183
+ """Dummy sampler used for grid search"""
184
+
185
+ def sample(
186
+ self,
187
+ domain: Domain,
188
+ config: Optional[Union[List[Dict], Dict]] = None,
189
+ size: int = 1,
190
+ random_state: "RandomState" = None,
191
+ ):
192
+ return RuntimeError("Do not call `sample()` on grid.")
193
+
194
+
195
+ @DeveloperAPI
196
+ class Float(Domain):
197
+ class _Uniform(Uniform):
198
+ def sample(
199
+ self,
200
+ domain: "Float",
201
+ config: Optional[Union[List[Dict], Dict]] = None,
202
+ size: int = 1,
203
+ random_state: "RandomState" = None,
204
+ ):
205
+ if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
206
+ random_state = _BackwardsCompatibleNumpyRng(random_state)
207
+ assert domain.lower > float("-inf"), "Uniform needs a lower bound"
208
+ assert domain.upper < float("inf"), "Uniform needs a upper bound"
209
+ items = random_state.uniform(domain.lower, domain.upper, size=size)
210
+ return items if len(items) > 1 else domain.cast(items[0])
211
+
212
+ class _LogUniform(LogUniform):
213
+ def sample(
214
+ self,
215
+ domain: "Float",
216
+ config: Optional[Union[List[Dict], Dict]] = None,
217
+ size: int = 1,
218
+ random_state: "RandomState" = None,
219
+ ):
220
+ if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
221
+ random_state = _BackwardsCompatibleNumpyRng(random_state)
222
+ assert domain.lower > 0, "LogUniform needs a lower bound greater than 0"
223
+ assert (
224
+ 0 < domain.upper < float("inf")
225
+ ), "LogUniform needs a upper bound greater than 0"
226
+ logmin = np.log(domain.lower) / np.log(self.base)
227
+ logmax = np.log(domain.upper) / np.log(self.base)
228
+
229
+ items = self.base ** (random_state.uniform(logmin, logmax, size=size))
230
+ return items if len(items) > 1 else domain.cast(items[0])
231
+
232
+ class _Normal(Normal):
233
+ def sample(
234
+ self,
235
+ domain: "Float",
236
+ config: Optional[Union[List[Dict], Dict]] = None,
237
+ size: int = 1,
238
+ random_state: "RandomState" = None,
239
+ ):
240
+ if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
241
+ random_state = _BackwardsCompatibleNumpyRng(random_state)
242
+ assert not domain.lower or domain.lower == float(
243
+ "-inf"
244
+ ), "Normal sampling does not allow a lower value bound."
245
+ assert not domain.upper or domain.upper == float(
246
+ "inf"
247
+ ), "Normal sampling does not allow a upper value bound."
248
+ items = random_state.normal(self.mean, self.sd, size=size)
249
+ return items if len(items) > 1 else domain.cast(items[0])
250
+
251
+ default_sampler_cls = _Uniform
252
+
253
+ def __init__(self, lower: Optional[float], upper: Optional[float]):
254
+ # Need to explicitly check for None
255
+ self.lower = lower if lower is not None else float("-inf")
256
+ self.upper = upper if upper is not None else float("inf")
257
+
258
+ def cast(self, value):
259
+ return float(value)
260
+
261
+ def uniform(self):
262
+ if not self.lower > float("-inf"):
263
+ raise ValueError(
264
+ "Uniform requires a lower bound. Make sure to set the "
265
+ "`lower` parameter of `Float()`."
266
+ )
267
+ if not self.upper < float("inf"):
268
+ raise ValueError(
269
+ "Uniform requires a upper bound. Make sure to set the "
270
+ "`upper` parameter of `Float()`."
271
+ )
272
+ new = copy(self)
273
+ new.set_sampler(self._Uniform())
274
+ return new
275
+
276
+ def loguniform(self, base: float = 10):
277
+ if not self.lower > 0:
278
+ raise ValueError(
279
+ "LogUniform requires a lower bound greater than 0."
280
+ f"Got: {self.lower}. Did you pass a variable that has "
281
+ "been log-transformed? If so, pass the non-transformed value "
282
+ "instead."
283
+ )
284
+ if not 0 < self.upper < float("inf"):
285
+ raise ValueError(
286
+ "LogUniform requires a upper bound greater than 0. "
287
+ f"Got: {self.lower}. Did you pass a variable that has "
288
+ "been log-transformed? If so, pass the non-transformed value "
289
+ "instead."
290
+ )
291
+ new = copy(self)
292
+ new.set_sampler(self._LogUniform(base))
293
+ return new
294
+
295
+ def normal(self, mean=0.0, sd=1.0):
296
+ new = copy(self)
297
+ new.set_sampler(self._Normal(mean, sd))
298
+ return new
299
+
300
+ def quantized(self, q: float):
301
+ if self.lower > float("-inf") and not isclose(
302
+ self.lower / q, round(self.lower / q)
303
+ ):
304
+ raise ValueError(
305
+ f"Your lower variable bound {self.lower} is not divisible by "
306
+ f"quantization factor {q}."
307
+ )
308
+ if self.upper < float("inf") and not isclose(
309
+ self.upper / q, round(self.upper / q)
310
+ ):
311
+ raise ValueError(
312
+ f"Your upper variable bound {self.upper} is not divisible by "
313
+ f"quantization factor {q}."
314
+ )
315
+
316
+ new = copy(self)
317
+ new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
318
+ return new
319
+
320
+ def is_valid(self, value: float):
321
+ return self.lower <= value <= self.upper
322
+
323
+ @property
324
+ def domain_str(self):
325
+ return f"({self.lower}, {self.upper})"
326
+
327
+
328
@DeveloperAPI
class Integer(Domain):
    """An integer-valued search space domain.

    For uniform sampling, ``lower`` is inclusive and ``upper`` is exclusive,
    matching the semantics documented for ``tune.randint``.
    """

    class _Uniform(Uniform):
        def sample(
            self,
            domain: "Integer",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            # Draws from [lower, upper) -- upper bound exclusive.
            items = random_state.integers(domain.lower, domain.upper, size=size)
            return items if len(items) > 1 else domain.cast(items[0])

    class _LogUniform(LogUniform):
        def sample(
            self,
            domain: "Integer",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            assert domain.lower > 0, "LogUniform needs a lower bound greater than 0"
            assert (
                0 < domain.upper < float("inf")
            ), "LogUniform needs a upper bound greater than 0"
            logmin = np.log(domain.lower) / np.log(self.base)
            logmax = np.log(domain.upper) / np.log(self.base)

            items = self.base ** (random_state.uniform(logmin, logmax, size=size))
            # Floor to obtain integers while keeping the upper bound exclusive.
            items = np.floor(items).astype(int)
            return items if len(items) > 1 else domain.cast(items[0])

    default_sampler_cls = _Uniform

    def __init__(self, lower, upper):
        self.lower = lower
        self.upper = upper

    def cast(self, value):
        """Coerce a sampled value to the domain's native type (int)."""
        return int(value)

    def quantized(self, q: int):
        """Return a copy whose samples are rounded to multiples of ``q``."""
        new = copy(self)
        new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
        return new

    def uniform(self):
        """Return a copy of this domain that samples uniformly."""
        new = copy(self)
        new.set_sampler(self._Uniform())
        return new

    def loguniform(self, base: float = 10):
        """Return a copy of this domain that samples log-uniformly.

        Args:
            base: Base of the logarithm. Defaults to 10.

        Raises:
            ValueError: If the bounds are not both positive and finite.
        """
        if not self.lower > 0:
            raise ValueError(
                "LogUniform requires a lower bound greater than 0. "
                f"Got: {self.lower}. Did you pass a variable that has "
                "been log-transformed? If so, pass the non-transformed value "
                "instead."
            )
        if not 0 < self.upper < float("inf"):
            raise ValueError(
                "LogUniform requires a upper bound greater than 0. "
                # Bug fix: report the offending upper bound (was self.lower).
                f"Got: {self.upper}. Did you pass a variable that has "
                "been log-transformed? If so, pass the non-transformed value "
                "instead."
            )
        new = copy(self)
        new.set_sampler(self._LogUniform(base))
        return new

    def is_valid(self, value: int):
        """Whether ``value`` lies within the closed interval [lower, upper]."""
        return self.lower <= value <= self.upper

    @property
    def domain_str(self):
        return f"({self.lower}, {self.upper})"
408
+
409
+
410
@DeveloperAPI
class Categorical(Domain):
    """A domain over an explicit, finite set of categories."""

    class _Uniform(Uniform):
        def sample(
            self,
            domain: "Categorical",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            # Sample indices instead of calling .choice() on the categories
            # directly, which would coerce mixed-type lists to one dtype.
            picked = random_state.choice(
                np.arange(0, len(domain.categories)), size=size
            )
            chosen = [domain.categories[i] for i in picked]
            if len(chosen) > 1:
                return chosen
            return domain.cast(chosen[0])

    default_sampler_cls = _Uniform

    def __init__(self, categories: Sequence):
        self.categories = list(categories)

    def uniform(self):
        """Return a copy that samples uniformly over the categories."""
        clone = copy(self)
        clone.set_sampler(self._Uniform())
        return clone

    def grid(self):
        """Return a copy that enumerates every category as a grid search."""
        clone = copy(self)
        clone.set_sampler(Grid())
        return clone

    def __len__(self):
        return len(self.categories)

    def __getitem__(self, item):
        return self.categories[item]

    def is_valid(self, value: Any):
        """Whether ``value`` is one of the categories."""
        return value in self.categories

    @property
    def domain_str(self):
        return f"{self.categories}"
457
+
458
+
459
@DeveloperAPI
class Function(Domain):
    """Domain backed by a user-supplied callable (see ``tune.sample_from``)."""

    class _CallSampler(BaseSampler):
        def __try_fn(self, domain: "Function", config: Dict[str, Any]):
            # First try the modern calling convention: func(config). If the
            # function instead expects the legacy spec-dict style (accessing
            # e.g. spec["config"]...), the resulting AttributeError/KeyError
            # triggers the deprecated fallback below.
            try:
                return domain.func(config)
            except (AttributeError, KeyError):
                from ray.tune.search.variant_generator import _UnresolvedAccessGuard

                r = domain.func(_UnresolvedAccessGuard({"config": config}))
                logger.warning(
                    "sample_from functions that take a spec dict are "
                    "deprecated. Please update your function to work with "
                    "the config dict directly."
                )
                return r

        def sample(
            self,
            domain: "Function",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            # random_state is normalized for interface consistency but the
            # user function controls its own randomness.
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            if domain.pass_config:
                # A list-valued `config` carries one config per sample.
                items = [
                    self.__try_fn(domain, config[i])
                    if isinstance(config, list)
                    else self.__try_fn(domain, config)
                    for i in range(size)
                ]
            else:
                items = [domain.func() for i in range(size)]

            return items if len(items) > 1 else domain.cast(items[0])

    default_sampler_cls = _CallSampler

    def __init__(self, func: Callable):
        """Wrap ``func``, detecting whether it accepts a config argument.

        Raises:
            ValueError: If ``func`` is callable with neither 0 nor 1
                positional arguments.
        """
        sig = signature(func)

        pass_config = True  # whether we should pass `config` when calling `func`
        try:
            # Probe: can the function be called with a single dict argument?
            sig.bind({})
        except TypeError:
            pass_config = False

        if not pass_config:
            try:
                # Otherwise it must be callable with no arguments at all.
                sig.bind()
            except TypeError as exc:
                raise ValueError(
                    "The function passed to a `Function` parameter must be "
                    "callable with either 0 or 1 parameters."
                ) from exc

        self.pass_config = pass_config
        self.func = func

    def is_function(self):
        return True

    def is_valid(self, value: Any):
        return True  # This is user-defined, so lets not assume anything

    @property
    def domain_str(self):
        return f"{self.func}()"
529
+
530
+
531
@DeveloperAPI
class Quantized(Sampler):
    """Sampler wrapper that rounds the wrapped sampler's values to multiples of ``q``."""

    def __init__(self, sampler: Sampler, q: Union[float, int]):
        # Underlying sampler that draws the raw (unquantized) values.
        self.sampler = sampler
        # Quantization increment.
        self.q = q

        assert self.sampler, "Quantized() expects a sampler instance"

    def get_sampler(self):
        return self.sampler

    def sample(
        self,
        domain: Domain,
        config: Optional[Union[List[Dict], Dict]] = None,
        size: int = 1,
        random_state: "RandomState" = None,
    ):
        if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
            random_state = _BackwardsCompatibleNumpyRng(random_state)

        if self.q == 1:
            # Increment of 1 makes rounding a no-op -- sample directly.
            return self.sampler.sample(domain, config, size, random_state=random_state)

        # Shrink the bounds to the innermost multiples of q so that the
        # rounding below cannot push a sample outside [lower, upper].
        quantized_domain = copy(domain)
        quantized_domain.lower = np.ceil(domain.lower / self.q) * self.q
        quantized_domain.upper = np.floor(domain.upper / self.q) * self.q
        values = self.sampler.sample(
            quantized_domain, config, size, random_state=random_state
        )
        # Round each drawn value to the nearest multiple of q.
        quantized = np.round(np.divide(values, self.q)) * self.q

        if not isinstance(quantized, np.ndarray):
            # Scalar result (size == 1): cast back to the domain's type.
            return domain.cast(quantized)
        return list(quantized)
566
+
567
+
568
@PublicAPI
def sample_from(func: Callable[[Dict], Any]):
    """Specify that tune should sample configuration values from this function.

    Args:
        func: A callable to draw a sample from. It may optionally accept the
            resolved config dict as its single argument.
    """
    return Function(func)
576
+
577
+
578
@PublicAPI
def uniform(lower: float, upper: float):
    """Sample a float value uniformly between ``lower`` and ``upper``.

    Sampling from ``tune.uniform(1, 10)`` is equivalent to sampling from
    ``np.random.uniform(1, 10))``

    """
    domain = Float(lower, upper)
    return domain.uniform()
587
+
588
+
589
@PublicAPI
def quniform(lower: float, upper: float, q: float):
    """Sample a quantized float value uniformly between ``lower`` and ``upper``.

    Sampling from ``tune.quniform(1, 10, 0.1)`` is equivalent to sampling
    from ``np.random.uniform(1, 10)`` and then rounding the result to the
    nearest multiple of ``0.1``.

    The value will be quantized, i.e. rounded to an integer increment of ``q``.
    Quantization makes the upper bound inclusive.

    """
    return Float(lower, upper).uniform().quantized(q)
601
+
602
+
603
@PublicAPI
def loguniform(lower: float, upper: float, base: float = 10):
    """Sugar for sampling in different orders of magnitude.

    Args:
        lower: Lower boundary of the output interval (e.g. 1e-4)
        upper: Upper boundary of the output interval (e.g. 1e-2)
        base: Base of the log. Defaults to 10.

    """
    domain = Float(lower, upper)
    return domain.loguniform(base)
614
+
615
+
616
@PublicAPI
def qloguniform(lower: float, upper: float, q: float, base: float = 10):
    """Sugar for sampling in different orders of magnitude.

    The value will be quantized, i.e. rounded to an integer increment of ``q``.

    Quantization makes the upper bound inclusive.

    Args:
        lower: Lower boundary of the output interval (e.g. 1e-4)
        upper: Upper boundary of the output interval (e.g. 1e-2)
        q: Quantization number. The result will be rounded to an
            integer increment of this value.
        base: Base of the log. Defaults to 10.

    """
    log_domain = Float(lower, upper).loguniform(base)
    return log_domain.quantized(q)
633
+
634
+
635
@PublicAPI
def choice(categories: Sequence):
    """Sample a categorical value.

    Sampling from ``tune.choice([1, 2])`` is equivalent to sampling from
    ``np.random.choice([1, 2])``

    """
    domain = Categorical(categories)
    return domain.uniform()
644
+
645
+
646
@PublicAPI
def randint(lower: int, upper: int):
    """Sample an integer value uniformly between ``lower`` and ``upper``.

    ``lower`` is inclusive, ``upper`` is exclusive.

    Sampling from ``tune.randint(10)`` is equivalent to sampling from
    ``np.random.randint(10)``

    .. versionchanged:: 1.5.0
        When converting Ray Tune configs to searcher-specific search spaces,
        the lower and upper limits are adjusted to keep compatibility with
        the bounds stated in the docstring above.

    """
    domain = Integer(lower, upper)
    return domain.uniform()
662
+
663
+
664
@PublicAPI
def lograndint(lower: int, upper: int, base: float = 10):
    """Sample an integer value log-uniformly between ``lower`` and ``upper``,
    with ``base`` being the base of logarithm.

    ``lower`` is inclusive, ``upper`` is exclusive.

    .. versionchanged:: 1.5.0
        When converting Ray Tune configs to searcher-specific search spaces,
        the lower and upper limits are adjusted to keep compatibility with
        the bounds stated in the docstring above.

    """
    domain = Integer(lower, upper)
    return domain.loguniform(base)
678
+
679
+
680
@PublicAPI
def qrandint(lower: int, upper: int, q: int = 1):
    """Sample an integer value uniformly between ``lower`` and ``upper``.

    ``lower`` is inclusive, ``upper`` is also inclusive (!).

    The value will be quantized, i.e. rounded to an integer increment of ``q``.
    Quantization makes the upper bound inclusive.

    .. versionchanged:: 1.5.0
        When converting Ray Tune configs to searcher-specific search spaces,
        the lower and upper limits are adjusted to keep compatibility with
        the bounds stated in the docstring above.

    """
    uniform_domain = Integer(lower, upper).uniform()
    return uniform_domain.quantized(q)
696
+
697
+
698
@PublicAPI
def qlograndint(lower: int, upper: int, q: int, base: float = 10):
    """Sample an integer value log-uniformly between ``lower`` and ``upper``,
    with ``base`` being the base of logarithm.

    ``lower`` is inclusive, ``upper`` is also inclusive (!).

    The value will be quantized, i.e. rounded to an integer increment of ``q``.
    Quantization makes the upper bound inclusive.

    .. versionchanged:: 1.5.0
        When converting Ray Tune configs to searcher-specific search spaces,
        the lower and upper limits are adjusted to keep compatibility with
        the bounds stated in the docstring above.

    """
    log_domain = Integer(lower, upper).loguniform(base)
    return log_domain.quantized(q)
715
+
716
+
717
@PublicAPI
def randn(mean: float = 0.0, sd: float = 1.0):
    """Sample a float value normally with ``mean`` and ``sd``.

    Args:
        mean: Mean of the normal distribution. Defaults to 0.
        sd: SD of the normal distribution. Defaults to 1.

    """
    unbounded = Float(None, None)
    return unbounded.normal(mean, sd)
727
+
728
+
729
@PublicAPI
def qrandn(mean: float, sd: float, q: float):
    """Sample a float value normally with ``mean`` and ``sd``.

    The value will be quantized, i.e. rounded to an integer increment of ``q``.

    Args:
        mean: Mean of the normal distribution.
        sd: SD of the normal distribution.
        q: Quantization number. The result will be rounded to an
            integer increment of this value.

    """
    normal_domain = Float(None, None).normal(mean, sd)
    return normal_domain.quantized(q)
.venv/lib/python3.11/site-packages/ray/tune/search/search_algorithm.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
2
+
3
+ from ray.util.annotations import DeveloperAPI
4
+
5
+ if TYPE_CHECKING:
6
+ from ray.tune.experiment import Experiment
7
+
8
+
9
@DeveloperAPI
class SearchAlgorithm:
    """Event-handler interface for hyperparameter search.

    Unlike TrialSchedulers, SearchAlgorithms will not have the ability
    to modify the execution (i.e., stop and pause trials).

    Trials added manually (i.e., via the Client API) will also notify
    this class upon new events, so custom search algorithms should
    maintain a list of trials ID generated from this class.

    See also: `ray.tune.search.BasicVariantGenerator`.
    """

    # True once this algorithm will generate no further trials.
    _finished = False

    # Metric this algorithm optimizes; None until configured.
    _metric = None

    @property
    def metric(self):
        """The metric this search algorithm optimizes, or None."""
        return self._metric

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
    ) -> bool:
        """Pass search properties to the search algorithm.

        This is an alternative to instantiating search algorithms with their
        own specific search spaces: a Tune config can be supplied through
        this method instead. Implementations usually forward this to their
        ``Searcher`` instance.

        Args:
            metric: Metric to optimize.
            mode: One of ["min", "max"]. Direction to optimize.
            config: Tune config dict.
            **spec: Any kwargs for forward compatiblity.
                Info like Experiment.PUBLIC_KEYS is provided through here.

        Returns:
            False when a metric was already configured and a new one is
            supplied; True otherwise.
        """
        if metric:
            if self._metric:
                # Refuse to overwrite an already-configured metric.
                return False
            self._metric = metric
        return True

    @property
    def total_samples(self):
        """Get number of total trials to be generated"""
        return 0

    def add_configurations(
        self, experiments: Union["Experiment", List["Experiment"], Dict[str, Dict]]
    ):
        """Tracks given experiment specifications.

        Args:
            experiments: Experiments to run.
        """
        raise NotImplementedError

    def next_trial(self):
        """Returns a single Trial object to be queued into the TrialRunner.

        Returns:
            trial: A Trial object.
        """
        raise NotImplementedError

    def on_trial_result(self, trial_id: str, result: Dict):
        """Called on each intermediate result returned by a trial.

        This will only be called when the trial is in the RUNNING state.

        Args:
            trial_id: Identifier for the trial.
            result: Result dictionary.
        """

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """Notification for the completion of trial.

        Args:
            trial_id: Identifier for the trial.
            result: Defaults to None. A dict will be provided with this
                notification when the trial is in the RUNNING state AND
                either completes naturally or by manual termination.
            error: Defaults to False. True if the trial is in the RUNNING
                state and errors.
        """

    def is_finished(self) -> bool:
        """Returns True if no trials left to be queued into TrialRunner.

        Can return True before all trials have finished executing.
        """
        return self._finished

    def set_finished(self):
        """Marks the search algorithm as finished."""
        self._finished = True

    def has_checkpoint(self, dirpath: str) -> bool:
        """Should return False if restoring is not implemented."""
        return False

    def save_to_dir(self, dirpath: str, **kwargs):
        """Saves a search algorithm."""

    def restore_from_dir(self, dirpath: str):
        """Restores a search algorithm along with its wrapped state."""
.venv/lib/python3.11/site-packages/ray/tune/search/search_generator.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ from typing import Dict, List, Optional, Union
4
+
5
+ from ray.tune.error import TuneError
6
+ from ray.tune.experiment import Experiment, Trial, _convert_to_experiment_list
7
+ from ray.tune.experiment.config_parser import _create_trial_from_spec, _make_parser
8
+ from ray.tune.search.search_algorithm import SearchAlgorithm
9
+ from ray.tune.search.searcher import Searcher
10
+ from ray.tune.search.util import _set_search_properties_backwards_compatible
11
+ from ray.tune.search.variant_generator import _resolve_nested_dict, format_vars
12
+ from ray.tune.utils.util import (
13
+ _atomic_save,
14
+ _load_newest_checkpoint,
15
+ flatten_dict,
16
+ merge_dicts,
17
+ )
18
+ from ray.util.annotations import DeveloperAPI
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ def _warn_on_repeater(searcher, total_samples):
24
+ from ray.tune.search.repeater import _warn_num_samples
25
+
26
+ _warn_num_samples(searcher, total_samples)
27
+
28
+
29
@DeveloperAPI
class SearchGenerator(SearchAlgorithm):
    """Generates trials to be passed to the TrialRunner.

    Uses the provided ``searcher`` object to generate trials. This class
    transparently handles repeating trials with score aggregation
    without embedding logic into the Searcher.

    Args:
        searcher: Search object that subclasses the Searcher base class. This
            is then used for generating new hyperparameter samples.
    """

    # Checkpoint file name pattern; "{}" is filled with the session string.
    CKPT_FILE_TMPL = "search_gen_state-{}.json"

    def __init__(self, searcher: Searcher):
        assert issubclass(
            type(searcher), Searcher
        ), "Searcher should be subclassing Searcher."
        self.searcher = searcher
        self._parser = _make_parser()
        self._experiment = None
        self._counter = 0  # Keeps track of number of trials created.
        self._total_samples = 0  # int: total samples to evaluate.
        self._finished = False

    @property
    def metric(self):
        # Forward to the wrapped searcher's metric.
        return self.searcher.metric

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
    ) -> bool:
        # Shim that tolerates searchers with older set_search_properties
        # signatures (without **spec).
        return _set_search_properties_backwards_compatible(
            self.searcher.set_search_properties, metric, mode, config, **spec
        )

    @property
    def total_samples(self):
        return self._total_samples

    def add_configurations(
        self, experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]
    ):
        """Registers experiment specifications.

        Arguments:
            experiments: Experiments to run.

        Raises:
            TuneError: If the experiment spec has no `run` entry.
        """
        assert not self._experiment
        logger.debug("added configurations")
        experiment_list = _convert_to_experiment_list(experiments)
        assert (
            len(experiment_list) == 1
        ), "SearchAlgorithms can only support 1 experiment at a time."
        self._experiment = experiment_list[0]
        experiment_spec = self._experiment.spec
        self._total_samples = self._experiment.spec.get("num_samples", 1)

        _warn_on_repeater(self.searcher, self._total_samples)
        if "run" not in experiment_spec:
            raise TuneError("Must specify `run` in {}".format(experiment_spec))

    def next_trial(self):
        """Provides one Trial object to be queued into the TrialRunner.

        Returns:
            Trial: Returns a single trial, or None if finished.
        """
        if not self.is_finished():
            return self.create_trial_if_possible(self._experiment.spec)
        return None

    def create_trial_if_possible(self, experiment_spec: Dict) -> Optional[Trial]:
        """Ask the searcher for a config and build a Trial from it.

        Returns None if the searcher has no suggestion available yet; marks
        this generator finished if the searcher reports FINISHED.
        """
        logger.debug("creating trial")
        trial_id = Trial.generate_id()
        suggested_config = self.searcher.suggest(trial_id)
        if suggested_config == Searcher.FINISHED:
            self._finished = True
            logger.debug("Searcher has finished.")
            return

        if suggested_config is None:
            return
        # Deep-copy so mutations of the trial spec never leak back into the
        # experiment spec or the searcher's suggestion.
        spec = copy.deepcopy(experiment_spec)
        spec["config"] = merge_dicts(spec["config"], copy.deepcopy(suggested_config))

        # Flatten the config and derive a human-readable experiment tag
        # from the counter plus the resolved variable values.
        flattened_config = _resolve_nested_dict(spec["config"])
        self._counter += 1
        tag = "{0}_{1}".format(str(self._counter), format_vars(flattened_config))
        trial = _create_trial_from_spec(
            spec,
            self._parser,
            evaluated_params=flatten_dict(suggested_config),
            experiment_tag=tag,
            trial_id=trial_id,
        )
        return trial

    def on_trial_result(self, trial_id: str, result: Dict):
        """Notifies the underlying searcher."""
        self.searcher.on_trial_result(trial_id, result)

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        # Forward completion (or failure) to the wrapped searcher.
        self.searcher.on_trial_complete(trial_id=trial_id, result=result, error=error)

    def is_finished(self) -> bool:
        # Finished when the sample budget is exhausted or the searcher
        # reported FINISHED.
        return self._counter >= self._total_samples or self._finished

    def get_state(self) -> Dict:
        """Serializable state of this generator (not the wrapped searcher)."""
        return {
            "counter": self._counter,
            "total_samples": self._total_samples,
            "finished": self._finished,
            "experiment": self._experiment,
        }

    def set_state(self, state: Dict):
        """Restore state previously produced by ``get_state``."""
        self._counter = state["counter"]
        self._total_samples = state["total_samples"]
        self._finished = state["finished"]
        self._experiment = state["experiment"]

    def has_checkpoint(self, dirpath: str):
        """Whether a saved checkpoint matching CKPT_FILE_TMPL exists in ``dirpath``."""
        return bool(_load_newest_checkpoint(dirpath, self.CKPT_FILE_TMPL.format("*")))

    def save_to_dir(self, dirpath: str, session_str: str):
        """Saves self + searcher to dir.

        Separates the "searcher" from its wrappers (concurrency, repeating).
        This allows the user to easily restore a given searcher.

        The save operation is atomic (write/swap).

        Args:
            dirpath: Filepath to experiment dir.
            session_str: Unique identifier of the current run
                session.
        """
        searcher = self.searcher
        search_alg_state = self.get_state()
        # Walk down the wrapper chain (each wrapper exposes `.searcher`),
        # recording each wrapper's state keyed by its class name, until the
        # innermost (base) searcher is reached.
        while hasattr(searcher, "searcher"):
            searcher_name = type(searcher).__name__
            if searcher_name in search_alg_state:
                logger.warning(
                    "There was a duplicate when saving {}. "
                    "Restore may not work properly.".format(searcher_name)
                )
            else:
                search_alg_state["name:" + searcher_name] = searcher.get_state()
            searcher = searcher.searcher
        base_searcher = searcher
        # We save the base searcher separately for users to easily
        # separate the searcher.
        base_searcher.save_to_dir(dirpath, session_str)
        _atomic_save(
            state=search_alg_state,
            checkpoint_dir=dirpath,
            file_name=self.CKPT_FILE_TMPL.format(session_str),
            tmp_file_name=".tmp_search_generator_ckpt",
        )

    def restore_from_dir(self, dirpath: str):
        """Restores self + searcher + search wrappers from dirpath."""

        searcher = self.searcher
        search_alg_state = _load_newest_checkpoint(
            dirpath, self.CKPT_FILE_TMPL.format("*")
        )
        if not search_alg_state:
            raise RuntimeError("Unable to find checkpoint in {}.".format(dirpath))
        # Walk the wrapper chain in the same order as save_to_dir, popping
        # each wrapper's saved state as it is consumed.
        while hasattr(searcher, "searcher"):
            searcher_name = "name:" + type(searcher).__name__
            if searcher_name not in search_alg_state:
                names = [
                    key.split("name:")[1]
                    for key in search_alg_state
                    if key.startswith("name:")
                ]
                logger.warning(
                    "{} was not found in the experiment "
                    "state when restoring. Found {}.".format(searcher_name, names)
                )
            else:
                searcher.set_state(search_alg_state.pop(searcher_name))
            searcher = searcher.searcher
        base_searcher = searcher

        logger.debug(f"searching base {base_searcher}")
        base_searcher.restore_from_dir(dirpath)
        # Remaining keys are this generator's own state.
        self.set_state(search_alg_state)
.venv/lib/python3.11/site-packages/ray/tune/search/searcher.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import glob
3
+ import logging
4
+ import os
5
+ import warnings
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
7
+
8
+ from ray.air._internal.usage import tag_searcher
9
+ from ray.tune.search.util import _set_search_properties_backwards_compatible
10
+ from ray.util.annotations import DeveloperAPI, PublicAPI
11
+ from ray.util.debug import log_once
12
+
13
+ if TYPE_CHECKING:
14
+ from ray.tune.analysis import ExperimentAnalysis
15
+ from ray.tune.experiment import Trial
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ @DeveloperAPI
21
+ class Searcher:
22
+ """Abstract class for wrapping suggesting algorithms.
23
+
24
+ Custom algorithms can extend this class easily by overriding the
25
+ `suggest` method provide generated parameters for the trials.
26
+
27
+ Any subclass that implements ``__init__`` must also call the
28
+ constructor of this class: ``super(Subclass, self).__init__(...)``.
29
+
30
+ To track suggestions and their corresponding evaluations, the method
31
+ `suggest` will be passed a trial_id, which will be used in
32
+ subsequent notifications.
33
+
34
+ Not all implementations support multi objectives.
35
+
36
+ Note to Tune developers: If a new searcher is added, please update
37
+ `air/_internal/usage.py`.
38
+
39
+ Args:
40
+ metric: The training result objective value attribute. If
41
+ list then list of training result objective value attributes
42
+ mode: If string One of {min, max}. If list then
43
+ list of max and min, determines whether objective is minimizing
44
+ or maximizing the metric attribute. Must match type of metric.
45
+
46
+ .. code-block:: python
47
+
48
+ class ExampleSearch(Searcher):
49
+ def __init__(self, metric="mean_loss", mode="min", **kwargs):
50
+ super(ExampleSearch, self).__init__(
51
+ metric=metric, mode=mode, **kwargs)
52
+ self.optimizer = Optimizer()
53
+ self.configurations = {}
54
+
55
+ def suggest(self, trial_id):
56
+ configuration = self.optimizer.query()
57
+ self.configurations[trial_id] = configuration
58
+
59
+ def on_trial_complete(self, trial_id, result, **kwargs):
60
+ configuration = self.configurations[trial_id]
61
+ if result and self.metric in result:
62
+ self.optimizer.update(configuration, result[self.metric])
63
+
64
+ tuner = tune.Tuner(
65
+ trainable_function,
66
+ tune_config=tune.TuneConfig(
67
+ search_alg=ExampleSearch()
68
+ )
69
+ )
70
+ tuner.fit()
71
+
72
+
73
+ """
74
+
75
+ FINISHED = "FINISHED"
76
+ CKPT_FILE_TMPL = "searcher-state-{}.pkl"
77
+
78
+ def __init__(
79
+ self,
80
+ metric: Optional[str] = None,
81
+ mode: Optional[str] = None,
82
+ ):
83
+ tag_searcher(self)
84
+ self._metric = metric
85
+ self._mode = mode
86
+
87
+ if not mode or not metric:
88
+ # Early return to avoid assertions
89
+ return
90
+
91
+ assert isinstance(
92
+ metric, type(mode)
93
+ ), "metric and mode must be of the same type"
94
+ if isinstance(mode, str):
95
+ assert mode in ["min", "max"], "if `mode` is a str must be 'min' or 'max'!"
96
+ elif isinstance(mode, list):
97
+ assert len(mode) == len(metric), "Metric and mode must be the same length"
98
+ assert all(
99
+ mod in ["min", "max", "obs"] for mod in mode
100
+ ), "All of mode must be 'min' or 'max' or 'obs'!"
101
+ else:
102
+ raise ValueError("Mode most either be a list or string")
103
+
104
+ def set_search_properties(
105
+ self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
106
+ ) -> bool:
107
+ """Pass search properties to searcher.
108
+
109
+ This method acts as an alternative to instantiating search algorithms
110
+ with their own specific search spaces. Instead they can accept a
111
+ Tune config through this method. A searcher should return ``True``
112
+ if setting the config was successful, or ``False`` if it was
113
+ unsuccessful, e.g. when the search space has already been set.
114
+
115
+ Args:
116
+ metric: Metric to optimize
117
+ mode: One of ["min", "max"]. Direction to optimize.
118
+ config: Tune config dict.
119
+ **spec: Any kwargs for forward compatiblity.
120
+ Info like Experiment.PUBLIC_KEYS is provided through here.
121
+ """
122
+ return False
123
+
124
+ def on_trial_result(self, trial_id: str, result: Dict) -> None:
125
+ """Optional notification for result during training.
126
+
127
+ Note that by default, the result dict may include NaNs or
128
+ may not include the optimization metric. It is up to the
129
+ subclass implementation to preprocess the result to
130
+ avoid breaking the optimization process.
131
+
132
+ Args:
133
+ trial_id: A unique string ID for the trial.
134
+ result: Dictionary of metrics for current training progress.
135
+ Note that the result dict may include NaNs or
136
+ may not include the optimization metric. It is up to the
137
+ subclass implementation to preprocess the result to
138
+ avoid breaking the optimization process.
139
+ """
140
+ pass
141
+
142
+ def on_trial_complete(
143
+ self, trial_id: str, result: Optional[Dict] = None, error: bool = False
144
+ ) -> None:
145
+ """Notification for the completion of trial.
146
+
147
+ Typically, this method is used for notifying the underlying
148
+ optimizer of the result.
149
+
150
+ Args:
151
+ trial_id: A unique string ID for the trial.
152
+ result: Dictionary of metrics for current training progress.
153
+ Note that the result dict may include NaNs or
154
+ may not include the optimization metric. It is up to the
155
+ subclass implementation to preprocess the result to
156
+ avoid breaking the optimization process. Upon errors, this
157
+ may also be None.
158
+ error: True if the training process raised an error.
159
+
160
+ """
161
+ raise NotImplementedError
162
+
163
+ def suggest(self, trial_id: str) -> Optional[Dict]:
164
+ """Queries the algorithm to retrieve the next set of parameters.
165
+
166
+ Arguments:
167
+ trial_id: Trial ID used for subsequent notifications.
168
+
169
+ Returns:
170
+ dict | FINISHED | None: Configuration for a trial, if possible.
171
+ If FINISHED is returned, Tune will be notified that
172
+ no more suggestions/configurations will be provided.
173
+ If None is returned, Tune will skip the querying of the
174
+ searcher for this step.
175
+
176
+ """
177
+ raise NotImplementedError
178
+
179
+ def add_evaluated_point(
180
+ self,
181
+ parameters: Dict,
182
+ value: float,
183
+ error: bool = False,
184
+ pruned: bool = False,
185
+ intermediate_values: Optional[List[float]] = None,
186
+ ):
187
+ """Pass results from a point that has been evaluated separately.
188
+
189
+ This method allows for information from outside the
190
+ suggest - on_trial_complete loop to be passed to the search
191
+ algorithm.
192
+ This functionality depends on the underlying search algorithm
193
+ and may not be always available.
194
+
195
+ Args:
196
+ parameters: Parameters used for the trial.
197
+ value: Metric value obtained in the trial.
198
+ error: True if the training process raised an error.
199
+ pruned: True if trial was pruned.
200
+ intermediate_values: List of metric values for
201
+ intermediate iterations of the result. None if not
202
+ applicable.
203
+
204
+ """
205
+ raise NotImplementedError
206
+
207
+ def add_evaluated_trials(
208
+ self,
209
+ trials_or_analysis: Union["Trial", List["Trial"], "ExperimentAnalysis"],
210
+ metric: str,
211
+ ):
212
+ """Pass results from trials that have been evaluated separately.
213
+
214
+ This method allows for information from outside the
215
+ suggest - on_trial_complete loop to be passed to the search
216
+ algorithm.
217
+ This functionality depends on the underlying search algorithm
218
+ and may not be always available (same as ``add_evaluated_point``.)
219
+
220
+ Args:
221
+ trials_or_analysis: Trials to pass results form to the searcher.
222
+ metric: Metric name reported by trials used for
223
+ determining the objective value.
224
+
225
+ """
226
+ if self.add_evaluated_point == Searcher.add_evaluated_point:
227
+ raise NotImplementedError
228
+
229
+ # lazy imports to avoid circular dependencies
230
+ from ray.tune.analysis import ExperimentAnalysis
231
+ from ray.tune.experiment import Trial
232
+ from ray.tune.result import DONE
233
+
234
+ if isinstance(trials_or_analysis, (list, tuple)):
235
+ trials = trials_or_analysis
236
+ elif isinstance(trials_or_analysis, Trial):
237
+ trials = [trials_or_analysis]
238
+ elif isinstance(trials_or_analysis, ExperimentAnalysis):
239
+ trials = trials_or_analysis.trials
240
+ else:
241
+ raise NotImplementedError(
242
+ "Expected input to be a `Trial`, a list of `Trial`s, or "
243
+ f"`ExperimentAnalysis`, got: {trials_or_analysis}"
244
+ )
245
+
246
+ any_trial_had_metric = False
247
+
248
+ def trial_to_points(trial: Trial) -> Dict[str, Any]:
249
+ nonlocal any_trial_had_metric
250
+ has_trial_been_pruned = (
251
+ trial.status == Trial.TERMINATED
252
+ and not trial.last_result.get(DONE, False)
253
+ )
254
+ has_trial_finished = (
255
+ trial.status == Trial.TERMINATED and trial.last_result.get(DONE, False)
256
+ )
257
+ if not any_trial_had_metric:
258
+ any_trial_had_metric = (
259
+ metric in trial.last_result and has_trial_finished
260
+ )
261
+ if Trial.TERMINATED and metric not in trial.last_result:
262
+ return None
263
+ return dict(
264
+ parameters=trial.config,
265
+ value=trial.last_result.get(metric, None),
266
+ error=trial.status == Trial.ERROR,
267
+ pruned=has_trial_been_pruned,
268
+ intermediate_values=None, # we do not save those
269
+ )
270
+
271
+ for trial in trials:
272
+ kwargs = trial_to_points(trial)
273
+ if kwargs:
274
+ self.add_evaluated_point(**kwargs)
275
+
276
+ if not any_trial_had_metric:
277
+ warnings.warn(
278
+ "No completed trial returned the specified metric. "
279
+ "Make sure the name you have passed is correct. "
280
+ )
281
+
282
+ def save(self, checkpoint_path: str):
283
+ """Save state to path for this search algorithm.
284
+
285
+ Args:
286
+ checkpoint_path: File where the search algorithm
287
+ state is saved. This path should be used later when
288
+ restoring from file.
289
+
290
+ Example:
291
+
292
+ .. code-block:: python
293
+
294
+ search_alg = Searcher(...)
295
+
296
+ tuner = tune.Tuner(
297
+ cost,
298
+ tune_config=tune.TuneConfig(
299
+ search_alg=search_alg,
300
+ num_samples=5
301
+ ),
302
+ param_space=config
303
+ )
304
+ results = tuner.fit()
305
+
306
+ search_alg.save("./my_favorite_path.pkl")
307
+
308
+ .. versionchanged:: 0.8.7
309
+ Save is automatically called by `Tuner().fit()`. You can use
310
+ `Tuner().restore()` to restore from an experiment directory
311
+ such as `~/ray_results/trainable`.
312
+
313
+ """
314
+ raise NotImplementedError
315
+
316
+ def restore(self, checkpoint_path: str):
317
+ """Restore state for this search algorithm
318
+
319
+
320
+ Args:
321
+ checkpoint_path: File where the search algorithm
322
+ state is saved. This path should be the same
323
+ as the one provided to "save".
324
+
325
+ Example:
326
+
327
+ .. code-block:: python
328
+
329
+ search_alg.save("./my_favorite_path.pkl")
330
+
331
+ search_alg2 = Searcher(...)
332
+ search_alg2 = ConcurrencyLimiter(search_alg2, 1)
333
+ search_alg2.restore(checkpoint_path)
334
+ tuner = tune.Tuner(
335
+ cost,
336
+ tune_config=tune.TuneConfig(
337
+ search_alg=search_alg2,
338
+ num_samples=5
339
+ ),
340
+ )
341
+ tuner.fit()
342
+
343
+ """
344
+ raise NotImplementedError
345
+
346
+ def set_max_concurrency(self, max_concurrent: int) -> bool:
347
+ """Set max concurrent trials this searcher can run.
348
+
349
+ This method will be called on the wrapped searcher by the
350
+ ``ConcurrencyLimiter``. It is intended to allow for searchers
351
+ which have custom, internal logic handling max concurrent trials
352
+ to inherit the value passed to ``ConcurrencyLimiter``.
353
+
354
+ If this method returns False, it signifies that no special
355
+ logic for handling this case is present in the searcher.
356
+
357
+ Args:
358
+ max_concurrent: Number of maximum concurrent trials.
359
+ """
360
+ return False
361
+
362
    def get_state(self) -> Dict:
        """Return a serializable dict of this searcher's internal state.

        Optional API for checkpointing; subclasses may override.
        """
        raise NotImplementedError
364
+
365
    def set_state(self, state: Dict):
        """Restore internal state from a dict produced by ``get_state()``.

        Optional API for checkpointing; subclasses may override.
        """
        raise NotImplementedError
367
+
368
    def save_to_dir(self, checkpoint_dir: str, session_str: str = "default"):
        """Automatically saves the given searcher to the checkpoint_dir.

        This is automatically used by Tuner().fit() during a Tune job.

        Args:
            checkpoint_dir: Filepath to experiment dir.
            session_str: Unique identifier of the current run
                session.
        """
        # Write to a temp file first so a partially written checkpoint is
        # never visible under the final checkpoint name.
        tmp_search_ckpt_path = os.path.join(checkpoint_dir, ".tmp_searcher_ckpt")
        success = True
        try:
            self.save(tmp_search_ckpt_path)
        except NotImplementedError:
            # Searchers without save() support simply skip checkpointing.
            if log_once("suggest:save_to_dir"):
                logger.warning("save not implemented for Searcher. Skipping save.")
            success = False

        if success and os.path.exists(tmp_search_ckpt_path):
            # Rename the temp file to the session-specific checkpoint name.
            os.replace(
                tmp_search_ckpt_path,
                os.path.join(checkpoint_dir, self.CKPT_FILE_TMPL.format(session_str)),
            )
392
+
393
    def restore_from_dir(self, checkpoint_dir: str):
        """Restores the state of a searcher from a given checkpoint_dir.

        Typically, you should use this function to restore from an
        experiment directory such as `~/ray_results/trainable`.

        .. code-block:: python

            tuner = tune.Tuner(
                cost,
                run_config=train.RunConfig(
                    name=self.experiment_name,
                    storage_path="~/my_results",
                ),
                tune_config=tune.TuneConfig(
                    search_alg=search_alg,
                    num_samples=5
                ),
                param_space=config
            )
            tuner.fit()

            search_alg2 = Searcher()
            search_alg2.restore_from_dir(
                os.path.join("~/my_results", self.experiment_name))
        """

        pattern = self.CKPT_FILE_TMPL.format("*")
        full_paths = glob.glob(os.path.join(checkpoint_dir, pattern))
        if not full_paths:
            raise RuntimeError(
                "Searcher unable to find checkpoint in {}".format(checkpoint_dir)
            )  # TODO
        # Checkpoint filenames embed the session string; the lexicographically
        # greatest path is taken to be the most recent checkpoint.
        most_recent_checkpoint = max(full_paths)
        self.restore(most_recent_checkpoint)
428
+
429
    @property
    def metric(self) -> str:
        """The training result objective value attribute."""
        # Set in __init__ (may be None until set via set_search_properties).
        return self._metric
433
+
434
    @property
    def mode(self) -> str:
        """Specifies if minimizing or maximizing the metric."""
        # Set in __init__ (may be None until set via set_search_properties).
        return self._mode
438
+
439
+
440
@PublicAPI
class ConcurrencyLimiter(Searcher):
    """A wrapper algorithm for limiting the number of concurrent trials.

    Certain Searchers have their own internal logic for limiting
    the number of concurrent trials. If such a Searcher is passed to a
    ``ConcurrencyLimiter``, the ``max_concurrent`` of the
    ``ConcurrencyLimiter`` will override the ``max_concurrent`` value
    of the Searcher. The ``ConcurrencyLimiter`` will then let the
    Searcher's internal logic take over.

    Args:
        searcher: Searcher object that the
            ConcurrencyLimiter will manage.
        max_concurrent: Maximum concurrent samples from the underlying
            searcher.
        batch: Whether to wait for all concurrent samples
            to finish before updating the underlying searcher.

    Example:

    .. code-block:: python

        from ray.tune.search import ConcurrencyLimiter
        search_alg = HyperOptSearch(metric="accuracy")
        search_alg = ConcurrencyLimiter(search_alg, max_concurrent=2)
        tuner = tune.Tuner(
            trainable_function,
            tune_config=tune.TuneConfig(
                search_alg=search_alg
            ),
        )
        tuner.fit()

    """

    def __init__(self, searcher: Searcher, max_concurrent: int, batch: bool = False):
        assert type(max_concurrent) is int and max_concurrent > 0
        self.searcher = searcher
        self.max_concurrent = max_concurrent
        self.batch = batch
        # Trial IDs that have been suggested but not yet completed.
        self.live_trials = set()
        self.num_unfinished_live_trials = 0
        # In batch mode: results buffered until the whole batch finishes.
        self.cached_results = {}
        # False when the wrapped searcher handles the limit itself.
        self._limit_concurrency = True

        if not isinstance(searcher, Searcher):
            raise RuntimeError(
                f"The `ConcurrencyLimiter` only works with `Searcher` "
                f"objects (got {type(searcher)}). Please try to pass "
                f"`max_concurrent` to the search generator directly."
            )

        self._set_searcher_max_concurrency()

        super(ConcurrencyLimiter, self).__init__(
            metric=self.searcher.metric, mode=self.searcher.mode
        )

    def _set_searcher_max_concurrency(self):
        # If the searcher has special logic for handling max concurrency,
        # we do not do anything inside the ConcurrencyLimiter
        self._limit_concurrency = not self.searcher.set_max_concurrency(
            self.max_concurrent
        )

    def set_max_concurrency(self, max_concurrent: int) -> bool:
        # Determine if this behavior is acceptable, or if it should
        # raise an exception.
        self.max_concurrent = max_concurrent
        return True

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
    ) -> bool:
        # Re-check: the wrapped searcher may only now be able to take over
        # concurrency handling.
        self._set_searcher_max_concurrency()
        return _set_search_properties_backwards_compatible(
            self.searcher.set_search_properties, metric, mode, config, **spec
        )

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Suggest a new config, or None while at the concurrency limit."""
        if not self._limit_concurrency:
            # The wrapped searcher limits concurrency internally.
            return self.searcher.suggest(trial_id)

        assert (
            trial_id not in self.live_trials
        ), f"Trial ID {trial_id} must be unique: already found in set."
        if len(self.live_trials) >= self.max_concurrent:
            logger.debug(
                f"Not providing a suggestion for {trial_id} due to "
                "concurrency limit: %s/%s.",
                len(self.live_trials),
                self.max_concurrent,
            )
            return

        suggestion = self.searcher.suggest(trial_id)
        # Only track trials that actually received a configuration.
        if suggestion not in (None, Searcher.FINISHED):
            self.live_trials.add(trial_id)
            self.num_unfinished_live_trials += 1
        return suggestion

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """Forward a completion to the wrapped searcher (batched if enabled)."""
        if not self._limit_concurrency:
            return self.searcher.on_trial_complete(trial_id, result=result, error=error)

        if trial_id not in self.live_trials:
            return
        elif self.batch:
            # Buffer the result; flush only when the whole batch is done.
            self.cached_results[trial_id] = (result, error)
            self.num_unfinished_live_trials -= 1
            if self.num_unfinished_live_trials <= 0:
                # Update the underlying searcher once the
                # full batch is completed.
                for trial_id, (result, error) in self.cached_results.items():
                    self.searcher.on_trial_complete(
                        trial_id, result=result, error=error
                    )
                    self.live_trials.remove(trial_id)
                self.cached_results = {}
                self.num_unfinished_live_trials = 0
            else:
                return
        else:
            self.searcher.on_trial_complete(trial_id, result=result, error=error)
            self.live_trials.remove(trial_id)
            self.num_unfinished_live_trials -= 1

    def on_trial_result(self, trial_id: str, result: Dict) -> None:
        # Intermediate results always pass straight through.
        self.searcher.on_trial_result(trial_id, result)

    def add_evaluated_point(
        self,
        parameters: Dict,
        value: float,
        error: bool = False,
        pruned: bool = False,
        intermediate_values: Optional[List[float]] = None,
    ):
        # Externally evaluated points do not count against the limit.
        return self.searcher.add_evaluated_point(
            parameters, value, error, pruned, intermediate_values
        )

    def get_state(self) -> Dict:
        state = self.__dict__.copy()
        # The wrapped searcher checkpoints itself via save()/restore().
        del state["searcher"]
        return copy.deepcopy(state)

    def set_state(self, state: Dict):
        self.__dict__.update(state)

    def save(self, checkpoint_path: str):
        self.searcher.save(checkpoint_path)

    def restore(self, checkpoint_path: str):
        self.searcher.restore(checkpoint_path)
.venv/lib/python3.11/site-packages/ray/tune/search/util.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Dict, Optional
3
+
4
+ logger = logging.getLogger(__name__)
5
+
6
+
7
+ def _set_search_properties_backwards_compatible(
8
+ set_search_properties_func,
9
+ metric: Optional[str],
10
+ mode: Optional[str],
11
+ config: Dict,
12
+ **spec
13
+ ) -> bool:
14
+ """Wraps around set_search_properties() so that it is backward compatible.
15
+
16
+ Also outputs a warning to encourage custom searchers to be updated.
17
+ """
18
+ try:
19
+ return set_search_properties_func(metric, mode, config, **spec)
20
+ except TypeError as e:
21
+ if str(e).startswith(
22
+ "set_search_properties() got an unexpected keyword argument"
23
+ ):
24
+ logger.warning(
25
+ "Please update custom Searcher to take in function signature "
26
+ "as ``def set_search_properties(metric, mode, config, "
27
+ "**spec) -> bool``."
28
+ )
29
+ return set_search_properties_func(metric, mode, config)
30
+ else:
31
+ raise e
.venv/lib/python3.11/site-packages/ray/tune/search/variant_generator.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import random
4
+ import re
5
+ from collections.abc import Mapping
6
+ from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
7
+
8
+ import numpy
9
+
10
+ from ray.tune.search.sample import Categorical, Domain, Function, RandomState
11
+ from ray.util.annotations import DeveloperAPI, PublicAPI
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
@DeveloperAPI
def generate_variants(
    unresolved_spec: Dict,
    constant_grid_search: bool = False,
    random_state: "RandomState" = None,
) -> Generator[Tuple[Dict, Dict], None, None]:
    """Generates variants from a spec (dict) with unresolved values.

    There are two types of unresolved values:

    Grid search: These define a grid search over values. For example, the
    following grid search values in a spec will produce six distinct
    variants in combination:

        "activation": grid_search(["relu", "tanh"])
        "learning_rate": grid_search([1e-3, 1e-4, 1e-5])

    Lambda functions: These are evaluated to produce a concrete value, and
    can express dependencies or conditional distributions between values.
    They can also be used to express random search (e.g., by calling
    into the `random` or `np` module).

        "cpu": lambda spec: spec.config.num_workers
        "batch_size": lambda spec: random.uniform(1, 1000)

    Finally, to support defining specs in plain JSON / YAML, grid search
    and lambda functions can also be defined alternatively as follows:

        "activation": {"grid_search": ["relu", "tanh"]}
        "cpu": {"eval": "spec.config.num_workers"}

    Use `format_vars` to format the returned dict of hyperparameters.

    Yields:
        (Dict of resolved variables, Spec object)
    """
    variant_iter = _generate_variants_internal(
        unresolved_spec,
        constant_grid_search=constant_grid_search,
        random_state=random_state,
    )
    for resolved_vars, resolved_spec in variant_iter:
        # Every emitted spec must be fully concrete at this point.
        assert not _unresolved_values(resolved_spec)
        yield resolved_vars, resolved_spec
59
+
60
+
61
+ @PublicAPI(stability="beta")
62
+ def grid_search(values: Iterable) -> Dict[str, Iterable]:
63
+ """Specify a grid of values to search over.
64
+
65
+ Values specified in a grid search are guaranteed to be sampled.
66
+
67
+ If multiple grid search variables are defined, they are combined with the
68
+ combinatorial product. This means every possible combination of values will
69
+ be sampled.
70
+
71
+ Example:
72
+
73
+ >>> from ray import tune
74
+ >>> param_space={
75
+ ... "x": tune.grid_search([10, 20]),
76
+ ... "y": tune.grid_search(["a", "b", "c"])
77
+ ... }
78
+
79
+ This will create a grid of 6 samples:
80
+ ``{"x": 10, "y": "a"}``, ``{"x": 10, "y": "b"}``, etc.
81
+
82
+ When specifying ``num_samples`` in the
83
+ :class:`TuneConfig <ray.tune.tune_config.TuneConfig>`, this will specify
84
+ the number of random samples per grid search combination.
85
+
86
+ For instance, in the example above, if ``num_samples=4``,
87
+ a total of 24 trials will be started -
88
+ 4 trials for each of the 6 grid search combinations.
89
+
90
+ Args:
91
+ values: An iterable whose parameters will be used for creating a trial grid.
92
+
93
+ """
94
+ return {"grid_search": values}
95
+
96
+
97
# Names made available as globals to `eval`-style lambda specs
# (e.g. {"eval": "np.random.uniform(0, 1)"}); see `_try_resolve`.
_STANDARD_IMPORTS = {
    "random": random,
    "np": numpy,
}

# Upper bound on dependency-resolution sweeps in `_resolve_domain_vars`;
# guards against specs with unresolvable recursive references.
_MAX_RESOLUTION_PASSES = 20
103
+
104
+
105
+ def _resolve_nested_dict(nested_dict: Dict) -> Dict[Tuple, Any]:
106
+ """Flattens a nested dict by joining keys into tuple of paths.
107
+
108
+ Can then be passed into `format_vars`.
109
+ """
110
+ res = {}
111
+ for k, v in nested_dict.items():
112
+ if isinstance(v, dict):
113
+ for k_, v_ in _resolve_nested_dict(v).items():
114
+ res[(k,) + k_] = v_
115
+ else:
116
+ res[(k,)] = v
117
+ return res
118
+
119
+
120
@DeveloperAPI
def format_vars(resolved_vars: Dict) -> str:
    """Format variables to be used as experiment tags.

    Experiment tags end up in directory names, so the values are sanitized
    to be legal on all file systems.

    The input is a dict of the form ``{("nested", "config", "path"):
    "value"}``; the output is a comma separated string of the form
    ``last_key=value`` (here: ``path=value``).

    Note that sanitizing may yield empty strings; this is expected and
    acceptable, as it is not a common case and the resulting directory
    names will still be valid.

    Args:
        resolved_vars: Dictionary mapping from config path tuples to a value.

    Returns:
        Comma-separated key=value string.
    """
    tag_vars = resolved_vars.copy()
    # TrialRunner already includes these in the experiment_tag.
    for excluded in ["run", "env", "resources_per_trial"]:
        tag_vars.pop(excluded, None)

    parts = [
        f"{_clean_value(path[-1])}={_clean_value(value)}"
        for path, value in sorted(tag_vars.items())
    ]
    return ",".join(parts)
150
+
151
+
152
+ def _flatten_resolved_vars(resolved_vars: Dict) -> Dict:
153
+ """Formats the resolved variable dict into a mapping of (str -> value)."""
154
+ flattened_resolved_vars_dict = {}
155
+ for pieces, value in resolved_vars.items():
156
+ if pieces[0] == "config":
157
+ pieces = pieces[1:]
158
+ pieces = [str(piece) for piece in pieces]
159
+ flattened_resolved_vars_dict["/".join(pieces)] = value
160
+ return flattened_resolved_vars_dict
161
+
162
+
163
+ def _clean_value(value: Any) -> str:
164
+ """Format floats and replace invalid string characters with ``_``."""
165
+ if isinstance(value, float):
166
+ return f"{value:.4f}"
167
+ else:
168
+ # Define an invalid alphabet, which is the inverse of the
169
+ # stated regex characters
170
+ invalid_alphabet = r"[^a-zA-Z0-9_-]+"
171
+ return re.sub(invalid_alphabet, "_", str(value)).strip("_")
172
+
173
+
174
@DeveloperAPI
def parse_spec_vars(
    spec: Dict,
) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[Tuple, Any]], List[Tuple[Tuple, Any]]]:
    """Split a spec into resolved, domain (random), and grid variables.

    Returns:
        Tuple of (resolved (path, value) pairs, domain variables to sample,
        grid-search variables sorted by path).
    """
    resolved, unresolved = _split_resolved_unresolved_values(spec)
    resolved_vars = list(resolved.items())

    if not unresolved:
        return resolved_vars, [], []

    grid_vars = []
    domain_vars = []
    for path, value in unresolved.items():
        # Grid variables are expanded combinatorially; the rest are sampled.
        target = grid_vars if value.is_grid() else domain_vars
        target.append((path, value))
    grid_vars.sort()

    return resolved_vars, domain_vars, grid_vars
194
+
195
+
196
def _count_spec_samples(spec: Dict, num_samples=1) -> int:
    """Count samples for a specific spec"""
    _, _, grid_vars = parse_spec_vars(spec)
    # Each grid variable multiplies the number of variants by its size.
    grid_count = 1
    for _, grid_domain in grid_vars:
        grid_count *= len(grid_domain.categories)
    return num_samples * grid_count
203
+
204
+
205
def _count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int:
    """Count the total number of trial variants a spec will generate.

    Args:
        spec: Experiment spec dict (may contain a top-level "num_samples").
        presets: Optional list of pre-set configs. Each preset consumes one
            of the requested samples and is counted with its own
            (possibly reduced) grid.

    Returns:
        Total number of variants across presets and remaining samples.
    """

    # Helper function: Deep update dictionary
    def deep_update(d, u):
        for k, v in u.items():
            if isinstance(v, Mapping):
                d[k] = deep_update(d.get(k, {}), v)
            else:
                d[k] = v
        return d

    # Fix: the default `presets=None` was iterated over directly, raising
    # TypeError. Treat None the same as an empty preset list.
    presets = presets or []

    total_samples = 0
    total_num_samples = spec.get("num_samples", 1)
    # For each preset, overwrite the spec and count the samples generated
    # for this preset
    for preset in presets:
        preset_spec = copy.deepcopy(spec)
        deep_update(preset_spec["config"], preset)
        total_samples += _count_spec_samples(preset_spec, 1)
        total_num_samples -= 1

    # Add the remaining samples
    if total_num_samples > 0:
        total_samples += _count_spec_samples(spec, total_num_samples)
    return total_samples
229
+
230
+
231
def _generate_variants_internal(
    spec: Dict, constant_grid_search: bool = False, random_state: "RandomState" = None
) -> Generator[Tuple[Dict, Dict], None, None]:
    """Recursively generates (resolved_vars, spec) variants from a spec.

    Grid-search variables are expanded combinatorially and domain (random)
    variables are sampled; the function recurses until no unresolved values
    remain in the spec.

    Args:
        spec: Spec dict that may contain unresolved Domain/grid values.
        constant_grid_search: If True, sample random variables once up front
            and keep them constant across grid-search combinations.
        random_state: Random state used for reproducible sampling.
    """
    spec = copy.deepcopy(spec)
    _, domain_vars, grid_vars = parse_spec_vars(spec)

    if not domain_vars and not grid_vars:
        # Fully resolved: emit the spec as-is.
        yield {}, spec
        return

    # Variables to resolve
    to_resolve = domain_vars

    all_resolved = True
    if constant_grid_search:
        # In this path, we first sample random variables and keep them constant
        # for grid search.
        # `_resolve_domain_vars` will alter `spec` directly
        all_resolved, resolved_vars = _resolve_domain_vars(
            spec, domain_vars, allow_fail=True, random_state=random_state
        )
        if not all_resolved:
            # Not all variables have been resolved, but remove those that have
            # from the `to_resolve` list.
            to_resolve = [(r, d) for r, d in to_resolve if r not in resolved_vars]
    grid_search = _grid_search_generator(spec, grid_vars)
    for resolved_spec in grid_search:
        if not constant_grid_search or not all_resolved:
            # In this path, we sample the remaining random variables
            _, resolved_vars = _resolve_domain_vars(
                resolved_spec, to_resolve, random_state=random_state
            )

        # Sampling may itself have introduced new unresolved values
        # (e.g. a lambda returning a Domain), so recurse.
        for resolved, spec in _generate_variants_internal(
            resolved_spec,
            constant_grid_search=constant_grid_search,
            random_state=random_state,
        ):
            for path, value in grid_vars:
                resolved_vars[path] = _get_value(spec, path)
            for k, v in resolved.items():
                if (
                    k in resolved_vars
                    and v != resolved_vars[k]
                    and _is_resolved(resolved_vars[k])
                ):
                    # Two conflicting concrete resolutions for the same
                    # variable indicate an ambiguous configuration.
                    raise ValueError(
                        "The variable `{}` could not be unambiguously "
                        "resolved to a single value. Consider simplifying "
                        "your configuration.".format(k)
                    )
                resolved_vars[k] = v
            yield resolved_vars, spec
284
+
285
+
286
def _get_preset_variants(
    spec: Dict,
    config: Dict,
    constant_grid_search: bool = False,
    random_state: "RandomState" = None,
):
    """Get variants according to a spec, initialized with a config.

    Variables from the spec are overwritten by the variables in the config.
    Thus, we may end up with less sampled parameters.

    This function also checks if values used to overwrite search space
    parameters are valid, and logs a warning if not.
    """
    spec = copy.deepcopy(spec)

    resolved, _, _ = parse_spec_vars(config)

    for path, val in resolved:
        try:
            domain = _get_value(spec["config"], path)
            if isinstance(domain, dict):
                if "grid_search" in domain:
                    # Treat a grid search as a categorical domain for
                    # validating the pre-set value below.
                    domain = Categorical(domain["grid_search"])
                else:
                    # If users want to overwrite an entire subdict,
                    # let them do it.
                    domain = None
        except IndexError as exc:
            raise ValueError(
                f"Pre-set config key `{'/'.join(path)}` does not correspond "
                f"to a valid key in the search space definition. Please add "
                f"this path to the `param_space` variable passed to `tune.Tuner()`."
            ) from exc

        if domain:
            if isinstance(domain, Domain):
                # Warn (but do not fail) if the pre-set value falls outside
                # the declared search space.
                if not domain.is_valid(val):
                    logger.warning(
                        f"Pre-set value `{val}` is not within valid values of "
                        f"parameter `{'/'.join(path)}`: {domain.domain_str}"
                    )
            else:
                # domain is actually a fixed value
                if domain != val:
                    logger.warning(
                        f"Pre-set value `{val}` is not equal to the value of "
                        f"parameter `{'/'.join(path)}`: {domain}"
                    )
        # The pre-set value always wins over the spec's definition.
        assign_value(spec["config"], path, val)

    return _generate_variants_internal(
        spec, constant_grid_search=constant_grid_search, random_state=random_state
    )
340
+
341
+
342
@DeveloperAPI
def assign_value(spec: Dict, path: Tuple, value: Any):
    """Assigns a value to a nested dictionary.

    Handles the special case of tuples: since tuples are immutable, the
    containing tuple is rebuilt with the updated element and written back
    into its parent container.
    """
    container = spec
    parent = None
    parent_key = None
    # Walk down to the direct container of the last path element, keeping
    # track of its parent so an immutable tuple can be replaced wholesale.
    for step in path[:-1]:
        parent = container
        parent_key = step
        container = container[step]

    leaf_key = path[-1]
    if isinstance(container, tuple):
        if parent is None:
            raise ValueError("Cannot assign value to a tuple.")
        assert isinstance(leaf_key, int), "Tuple key must be an int."
        # Special handling since tuples are immutable: splice in the new
        # element and re-assign the rebuilt tuple into the parent.
        parent[parent_key] = (
            container[:leaf_key] + (value,) + container[leaf_key + 1 :]
        )
    else:
        # Mutable container: assign directly.
        container[leaf_key] = value
365
+
366
+
367
+ def _get_value(spec: Dict, path: Tuple) -> Any:
368
+ for k in path:
369
+ spec = spec[k]
370
+ return spec
371
+
372
+
373
def _resolve_domain_vars(
    spec: Dict,
    domain_vars: List[Tuple[Tuple, Domain]],
    allow_fail: bool = False,
    random_state: "RandomState" = None,
) -> Tuple[bool, Dict]:
    """Samples all domain variables of `spec`, mutating it in place.

    Variables may reference each other (via lambda/`eval` specs), so
    sampling is retried in passes (up to `_MAX_RESOLUTION_PASSES`) until
    all dependencies resolve or a recursive dependency persists.

    Returns:
        Tuple of (whether all variables were resolved, dict mapping each
        resolved path to its sampled value).

    Raises:
        RecursiveDependencyError: If resolution still fails after the
            maximum number of passes and ``allow_fail`` is False.
        ValueError: If evaluating an expression fails for any other reason.
    """
    resolved = {}
    error = True
    num_passes = 0
    while error and num_passes < _MAX_RESOLUTION_PASSES:
        num_passes += 1
        error = False
        for path, domain in domain_vars:
            if path in resolved:
                continue
            try:
                value = domain.sample(
                    _UnresolvedAccessGuard(spec), random_state=random_state
                )
            except RecursiveDependencyError as e:
                # Depends on a not-yet-resolved variable; retry next pass.
                error = e
            except Exception:
                raise ValueError(
                    "Failed to evaluate expression: {}: {}".format(path, domain)
                )
            else:
                assign_value(spec, path, value)
                resolved[path] = value
    if error:
        if not allow_fail:
            raise error
        else:
            return False, resolved
    return True, resolved
407
+
408
+
409
+ def _grid_search_generator(
410
+ unresolved_spec: Dict, grid_vars: List
411
+ ) -> Generator[Dict, None, None]:
412
+ value_indices = [0] * len(grid_vars)
413
+
414
+ def increment(i):
415
+ value_indices[i] += 1
416
+ if value_indices[i] >= len(grid_vars[i][1]):
417
+ value_indices[i] = 0
418
+ if i + 1 < len(value_indices):
419
+ return increment(i + 1)
420
+ else:
421
+ return True
422
+ return False
423
+
424
+ if not grid_vars:
425
+ yield unresolved_spec
426
+ return
427
+
428
+ while value_indices[-1] < len(grid_vars[-1][1]):
429
+ spec = copy.deepcopy(unresolved_spec)
430
+ for i, (path, values) in enumerate(grid_vars):
431
+ assign_value(spec, path, values[value_indices[i]])
432
+ yield spec
433
+ if grid_vars:
434
+ done = increment(0)
435
+ if done:
436
+ break
437
+
438
+
439
def _is_resolved(v) -> bool:
    """True if ``v`` is a concrete value rather than something to sample."""
    is_done, _unused = _try_resolve(v)
    return is_done
442
+
443
+
444
def _try_resolve(v) -> Tuple[bool, Any]:
    """Classify ``v`` as resolved or unresolved.

    Returns ``(True, v)`` for plain values. For a ``Domain``, a
    ``{"eval": ...}`` lambda spec, or a ``{"grid_search": ...}`` spec it
    returns ``(False, sampler)``, where ``sampler`` is the object that
    later produces concrete values.
    """
    if isinstance(v, Domain):
        # Already a search-space object to sample from.
        return False, v
    if isinstance(v, dict) and len(v) == 1:
        if "eval" in v:
            # Lambda in tune's eval syntax; wrap so it is evaluated lazily
            # against the (guarded) resolved spec.
            return False, Function(
                lambda spec: eval(v["eval"], _STANDARD_IMPORTS, {"spec": spec})
            )
        if "grid_search" in v:
            # Grid-search values become a Categorical marked for grid mode.
            return False, Categorical(v["grid_search"]).grid()
    return True, v
458
+
459
+
460
def _split_resolved_unresolved_values(
    spec: Dict,
) -> Tuple[Dict[Tuple, Any], Dict[Tuple, Any]]:
    """Partition ``spec`` into resolved and unresolved leaf entries.

    Returns two dicts keyed by key-path tuples: concrete values, and
    values that still need sampling or grid expansion.
    """
    done: Dict[Tuple, Any] = {}
    pending: Dict[Tuple, Any] = {}

    def _merge(prefix, child_done, child_pending):
        # Re-root the recursively collected paths under ``prefix``.
        for sub_path, sub_val in child_done.items():
            done[prefix + sub_path] = sub_val
        for sub_path, sub_val in child_pending.items():
            pending[prefix + sub_path] = sub_val

    for key, raw in spec.items():
        is_done, val = _try_resolve(raw)
        if not is_done:
            pending[(key,)] = val
        elif isinstance(val, dict):
            # Recurse into nested dicts.
            _merge((key,), *_split_resolved_unresolved_values(val))
        elif isinstance(val, (list, tuple)):
            # Treat each element as a one-entry dict keyed by its index so
            # that list positions appear in the paths.
            for idx, item in enumerate(val):
                _merge((key,), *_split_resolved_unresolved_values({idx: item}))
        else:
            done[(key,)] = val
    return done, pending
493
+
494
+
495
def _unresolved_values(spec: Dict) -> Dict[Tuple, Any]:
    """Return only the unresolved half of the resolved/unresolved split."""
    _resolved, unresolved = _split_resolved_unresolved_values(spec)
    return unresolved
497
+
498
+
499
def _has_unresolved_values(spec: Dict) -> bool:
    """Whether ``spec`` still contains any value that needs resolution."""
    return bool(_unresolved_values(spec))
501
+
502
+
503
class _UnresolvedAccessGuard(dict):
    """Dict view of a spec that raises when unresolved entries are read.

    Used while sampling spec-dependent expressions: reading an entry that
    still needs resolution raises ``RecursiveDependencyError`` so the
    caller can retry on a later pass.
    """

    def __init__(self, *args, **kwds):
        super(_UnresolvedAccessGuard, self).__init__(*args, **kwds)
        # Alias attribute storage to the dict itself so every key is also
        # reachable as an attribute (e.g. ``spec.config``).
        self.__dict__ = self

    def __getattribute__(self, item):
        value = dict.__getattribute__(self, item)
        if not _is_resolved(value):
            # The accessed entry still needs sampling: the caller's
            # expression depends on a not-yet-resolved variable.
            raise RecursiveDependencyError(
                "`{}` recursively depends on {}".format(item, value)
            )
        elif isinstance(value, dict):
            # Wrap nested dicts so the guard applies recursively.
            return _UnresolvedAccessGuard(value)
        else:
            return value
518
+
519
+
520
@DeveloperAPI
class RecursiveDependencyError(Exception):
    """Raised when resolving one search variable requires reading another
    variable that is itself still unresolved."""

    def __init__(self, msg: str):
        super().__init__(msg)
.venv/lib/python3.11/site-packages/ray/tune/stopper/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.tune.stopper.experiment_plateau import ExperimentPlateauStopper
2
+ from ray.tune.stopper.function_stopper import FunctionStopper
3
+ from ray.tune.stopper.maximum_iteration import MaximumIterationStopper
4
+ from ray.tune.stopper.noop import NoopStopper
5
+ from ray.tune.stopper.stopper import CombinedStopper, Stopper
6
+ from ray.tune.stopper.timeout import TimeoutStopper
7
+ from ray.tune.stopper.trial_plateau import TrialPlateauStopper
8
+
9
+ __all__ = [
10
+ "Stopper",
11
+ "CombinedStopper",
12
+ "ExperimentPlateauStopper",
13
+ "FunctionStopper",
14
+ "MaximumIterationStopper",
15
+ "NoopStopper",
16
+ "TimeoutStopper",
17
+ "TrialPlateauStopper",
18
+ ]
.venv/lib/python3.11/site-packages/ray/tune/stopper/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (894 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/stopper/__pycache__/experiment_plateau.cpython-311.pyc ADDED
Binary file (4.52 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/stopper/__pycache__/function_stopper.cpython-311.pyc ADDED
Binary file (2.3 kB). View file