koichi12 commited on
Commit
39cf1df
·
verified ·
1 Parent(s): e8fd6d0

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/callback.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/constants.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/context.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/error.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/progress_reporter.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/registry.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/resources.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/result.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/result_grid.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/syncer.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune_config.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/ray/tune/__pycache__/tuner.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/ray/tune/cli/__init__.py +0 -0
  16. .venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/__init__.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/commands.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/scripts.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/ray/tune/cli/commands.py +306 -0
  20. .venv/lib/python3.11/site-packages/ray/tune/cli/scripts.py +101 -0
  21. .venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/__init__.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/utils.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__init__.py +0 -0
  24. .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/__init__.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/common.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_func.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_trainable.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/common.py +285 -0
  29. .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_func.py +191 -0
  30. .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_trainable.py +185 -0
  31. .venv/lib/python3.11/site-packages/ray/tune/experimental/__init__.py +0 -0
  32. .venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/__init__.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/output.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/ray/tune/experimental/output.py +1043 -0
  35. .venv/lib/python3.11/site-packages/ray/tune/logger/__init__.py +32 -0
  36. .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/__init__.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/aim.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/comet.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/csv.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/json.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/logger.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/mlflow.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/noop.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/tensorboardx.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/unified.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/wandb.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/ray/tune/logger/aim.py +187 -0
  48. .venv/lib/python3.11/site-packages/ray/tune/logger/comet.py +3 -0
  49. .venv/lib/python3.11/site-packages/ray/tune/logger/csv.py +135 -0
  50. .venv/lib/python3.11/site-packages/ray/tune/logger/json.py +128 -0
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.31 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/callback.cpython-311.pyc ADDED
Binary file (23.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/constants.cpython-311.pyc ADDED
Binary file (988 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/context.cpython-311.pyc ADDED
Binary file (6.11 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/error.cpython-311.pyc ADDED
Binary file (2.54 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/progress_reporter.cpython-311.pyc ADDED
Binary file (73.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/registry.cpython-311.pyc ADDED
Binary file (15.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/resources.cpython-311.pyc ADDED
Binary file (3.62 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/result.cpython-311.pyc ADDED
Binary file (1.94 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/result_grid.cpython-311.pyc ADDED
Binary file (14.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/syncer.cpython-311.pyc ADDED
Binary file (1.06 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune.cpython-311.pyc ADDED
Binary file (48 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune_config.cpython-311.pyc ADDED
Binary file (6.13 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tuner.cpython-311.pyc ADDED
Binary file (18.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/cli/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (185 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/commands.cpython-311.pyc ADDED
Binary file (15.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/scripts.cpython-311.pyc ADDED
Binary file (4.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/cli/commands.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import operator
3
+ import os
4
+ import shutil
5
+ import subprocess
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from typing import List, Optional
9
+
10
+ import click
11
+ import pandas as pd
12
+ from pandas.api.types import is_numeric_dtype, is_string_dtype
13
+
14
+ from ray._private.thirdparty.tabulate.tabulate import tabulate
15
+ from ray.air.constants import EXPR_RESULT_FILE
16
+ from ray.tune import TuneError
17
+ from ray.tune.analysis import ExperimentAnalysis
18
+ from ray.tune.result import (
19
+ CONFIG_PREFIX,
20
+ DEFAULT_EXPERIMENT_INFO_KEYS,
21
+ DEFAULT_RESULT_KEYS,
22
+ )
23
+
24
logger = logging.getLogger(__name__)

# Editor launched by `add_note`; falls back to vim when $EDITOR is unset.
EDITOR = os.getenv("EDITOR", "vim")

# Human-readable timestamp, e.g. "2024-01-01 12:00:00 (Monday)".
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)"

# Default columns shown by `list_trials` (experiment info + result keys).
DEFAULT_CLI_KEYS = DEFAULT_EXPERIMENT_INFO_KEYS + DEFAULT_RESULT_KEYS

# Default columns shown by `list_experiments`.
DEFAULT_PROJECT_INFO_KEYS = (
    "name",
    "total_trials",
    "last_updated",
)

# Current terminal dimensions; used to truncate tables that are too wide.
TERM_WIDTH, TERM_HEIGHT = shutil.get_terminal_size(fallback=(100, 100))

# Maps the comparison token accepted by --filter to its comparison function.
OPERATORS = {
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}
48
+
49
+
50
def _check_tabulate():
    """Checks whether tabulate is installed."""
    if tabulate is not None:
        return
    raise ImportError("Tabulate not installed. Please run `pip install tabulate`.")
54
+
55
+
56
def print_format_output(dataframe):
    """Prints output of given dataframe to fit into terminal.

    Returns:
        table: Final outputted dataframe.
        dropped_cols: Columns dropped due to terminal size.
        empty_cols: Empty columns (dropped on default).
    """
    display_df = pd.DataFrame()
    dropped_cols, empty_cols = [], []
    # Columns are added in the order given; earlier columns win when space
    # runs out (priority follows the info_keys passed in upstream).
    for idx, column in enumerate(dataframe):
        if dataframe[column].isnull().all():
            # Fully-empty columns are never displayed.
            empty_cols.append(column)
            continue

        display_df[column] = dataframe[column]
        probe = tabulate(display_df, headers="keys", tablefmt="psql")
        # Width of the rendered table == offset of its first newline.
        if str(probe).index("\n") > TERM_WIDTH:
            # Too wide: undo the last addition and drop everything after it.
            display_df.drop(column, axis=1, inplace=True)
            dropped_cols.extend(list(dataframe.columns)[idx:])
            break

    table = tabulate(display_df, headers="keys", tablefmt="psql", showindex="never")

    print(table)
    if dropped_cols:
        click.secho("Dropped columns: {}".format(dropped_cols), fg="yellow")
        click.secho("Please increase your terminal size to view remaining columns.")
    if empty_cols:
        click.secho("Empty columns: {}".format(empty_cols), fg="yellow")

    return table, dropped_cols, empty_cols
92
+
93
+
94
def list_trials(
    experiment_path: str,
    sort: Optional[List[str]] = None,
    output: Optional[str] = None,
    filter_op: Optional[str] = None,
    info_keys: Optional[List[str]] = None,
    limit: Optional[int] = None,
    desc: bool = False,
):
    """Lists trials in the directory subtree starting at the given path.

    Args:
        experiment_path: Directory where trials are located.
            Like Experiment.local_dir/Experiment.name/experiment*.json.
        sort: Keys to sort by.
        output: Name of file where output is saved.
        filter_op: Filter operation in the format
            "<column> <operator> <value>".
        info_keys: Keys that are displayed.
        limit: Number of rows to display.
        desc: Sort ascending vs. descending.
    """
    _check_tabulate()

    try:
        checkpoints_df = ExperimentAnalysis(experiment_path).dataframe()  # last result
    except TuneError as e:
        raise click.ClickException("No trial data found!") from e

    config_prefix = CONFIG_PREFIX + "/"

    def key_filter(k):
        # Keep well-known experiment/result keys plus every config/* column.
        return k in DEFAULT_CLI_KEYS or k.startswith(config_prefix)

    col_keys = [k for k in checkpoints_df.columns if key_filter(k)]

    if info_keys:
        # An explicit column selection overrides the default key filter.
        for k in info_keys:
            if k not in checkpoints_df.columns:
                raise click.ClickException(
                    "Provided key invalid: {}. "
                    "Available keys: {}.".format(k, checkpoints_df.columns)
                )
        col_keys = [k for k in checkpoints_df.columns if k in info_keys]

    if not col_keys:
        raise click.ClickException("No columns to output.")

    checkpoints_df = checkpoints_df[col_keys]
    if "last_update_time" in checkpoints_df:
        # Treat +/-inf as NA so dropna() removes them before formatting.
        # NOTE(review): the "mode.use_inf_as_null" option was removed in
        # pandas 2.x -- confirm the pinned pandas version still supports it.
        with pd.option_context("mode.use_inf_as_null", True):
            datetime_series = checkpoints_df["last_update_time"].dropna()

        # Render epoch seconds as human-readable local timestamps.
        datetime_series = datetime_series.apply(
            lambda t: datetime.fromtimestamp(t).strftime(TIMESTAMP_FORMAT)
        )
        checkpoints_df["last_update_time"] = datetime_series

    if "logdir" in checkpoints_df:
        # logdir often too long to view in table, so drop experiment_path
        # NOTE(review): depending on pandas version, str.replace may treat
        # experiment_path as a regex -- verify paths with special characters.
        checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace(
            experiment_path, ""
        )

    if filter_op:
        # Expected format: "<column> <operator> <value>", space-separated.
        col, op, val = filter_op.split(" ")
        col_type = checkpoints_df[col].dtype
        if is_numeric_dtype(col_type):
            val = float(val)
        elif is_string_dtype(col_type):
            val = str(val)
        # TODO(Andrew): add support for datetime and boolean
        else:
            raise click.ClickException(
                "Unsupported dtype for {}: {}".format(val, col_type)
            )
        op = OPERATORS[op]
        filtered_index = op(checkpoints_df[col], val)
        checkpoints_df = checkpoints_df[filtered_index]

    if sort:
        for key in sort:
            if key not in checkpoints_df:
                raise click.ClickException(
                    "{} not in: {}".format(key, list(checkpoints_df))
                )
        ascending = not desc
        checkpoints_df = checkpoints_df.sort_values(by=sort, ascending=ascending)

    if limit:
        checkpoints_df = checkpoints_df[:limit]

    print_format_output(checkpoints_df)

    if output:
        # Persist the (filtered/sorted/limited) table; format by extension.
        file_extension = os.path.splitext(output)[1].lower()
        if file_extension in (".p", ".pkl", ".pickle"):
            checkpoints_df.to_pickle(output)
        elif file_extension == ".csv":
            checkpoints_df.to_csv(output, index=False)
        else:
            raise click.ClickException("Unsupported filetype: {}".format(output))
        click.secho("Output saved at {}".format(output), fg="green")
197
+
198
+
199
def list_experiments(
    project_path: str,
    sort: Optional[List[str]] = None,
    output: Optional[str] = None,
    filter_op: Optional[str] = None,
    info_keys: Optional[List[str]] = None,
    limit: Optional[int] = None,
    desc: bool = False,
):
    """Lists experiments in the directory subtree.

    Args:
        project_path: Directory where experiments are located.
            Corresponds to Experiment.local_dir.
        sort: Keys to sort by.
        output: Name of file where output is saved.
        filter_op: Filter operation in the format
            "<column> <operator> <value>".
        info_keys: Keys that are displayed.
        limit: Number of rows to display.
        desc: Sort ascending vs. descending.
    """
    _check_tabulate()
    # Only the immediate children of project_path count as experiments.
    base, experiment_folders, _ = next(os.walk(project_path))

    experiment_data_collection = []

    for experiment_dir in experiment_folders:
        # A trial is any subdirectory that contains a result file.
        num_trials = sum(
            EXPR_RESULT_FILE in files
            for _, _, files in os.walk(os.path.join(base, experiment_dir))
        )

        experiment_data = {"name": experiment_dir, "total_trials": num_trials}
        experiment_data_collection.append(experiment_data)

    if not experiment_data_collection:
        raise click.ClickException("No experiments found!")

    info_df = pd.DataFrame(experiment_data_collection)
    if not info_keys:
        info_keys = DEFAULT_PROJECT_INFO_KEYS
    # Silently drop requested keys that are absent; error only if none match.
    col_keys = [k for k in list(info_keys) if k in info_df]
    if not col_keys:
        raise click.ClickException(
            "None of keys {} in experiment data!".format(info_keys)
        )
    info_df = info_df[col_keys]

    if filter_op:
        # Expected format: "<column> <operator> <value>", space-separated.
        col, op, val = filter_op.split(" ")
        col_type = info_df[col].dtype
        if is_numeric_dtype(col_type):
            val = float(val)
        elif is_string_dtype(col_type):
            val = str(val)
        # TODO(Andrew): add support for datetime and boolean
        else:
            raise click.ClickException(
                "Unsupported dtype for {}: {}".format(val, col_type)
            )
        op = OPERATORS[op]
        filtered_index = op(info_df[col], val)
        info_df = info_df[filtered_index]

    if sort:
        for key in sort:
            if key not in info_df:
                raise click.ClickException("{} not in: {}".format(key, list(info_df)))
        ascending = not desc
        info_df = info_df.sort_values(by=sort, ascending=ascending)

    if limit:
        info_df = info_df[:limit]

    print_format_output(info_df)

    if output:
        # Persist the table; output format is chosen by file extension.
        file_extension = os.path.splitext(output)[1].lower()
        if file_extension in (".p", ".pkl", ".pickle"):
            info_df.to_pickle(output)
        elif file_extension == ".csv":
            info_df.to_csv(output, index=False)
        else:
            raise click.ClickException("Unsupported filetype: {}".format(output))
        click.secho("Output saved at {}".format(output), fg="green")
285
+
286
+
287
def add_note(path: str, filename: str = "note.txt"):
    """Opens a txt file at the given path where user can add and save notes.

    Args:
        path: Directory where note will be saved.
        filename: Name of note. Defaults to "note.txt"

    Raises:
        click.ClickException: If ``path`` is not an existing directory.
    """
    path = Path(path).expanduser()
    # Validate with a CLI-friendly error instead of `assert`, which is
    # stripped under `python -O` and would let an invalid path through.
    # ClickException also matches the error style used by the sibling
    # list_trials/list_experiments commands.
    if not path.is_dir():
        raise click.ClickException("{} is not a valid directory.".format(path))

    filepath = path / filename

    try:
        # Block until the user's $EDITOR (default: vim) exits.
        subprocess.call([EDITOR, filepath.as_posix()])
    except Exception as exc:
        click.secho("Editing note failed: {}".format(str(exc)), fg="red")
    # Report based on whether the file exists after the editing session.
    if filepath.exists():
        print("Note updated at:", filepath.as_posix())
    else:
        print("Note created at:", filepath.as_posix())
.venv/lib/python3.11/site-packages/ray/tune/cli/scripts.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import click
2
+
3
+ import ray.tune.cli.commands as commands
4
+
5
+
6
@click.group()
def cli():
    # Root command group; subcommands and aliases are registered below.
    # (Intentionally no docstring: click would surface it as help text.)
    pass
9
+
10
+
11
@cli.command()
@click.argument("experiment_path", required=True, type=str)
@click.option("--sort", default=None, type=str, help="Select which column to sort on.")
@click.option(
    "--output",
    "-o",
    default=None,
    type=str,
    help="Select file to output information to.",
)
@click.option(
    "--filter",
    "filter_op",
    default=None,
    type=str,
    help="Select filter in the format '<column> <operator> <value>'.",
)
@click.option(
    "--columns", default=None, type=str, help="Select columns to be displayed."
)
@click.option(
    "--limit", default=None, type=int, help="Select number of rows to display."
)
@click.option("--desc", default=False, type=bool, help="Sort ascending vs. descending.")
def list_trials(experiment_path, sort, output, filter_op, columns, limit, desc):
    """Lists trials in the directory subtree starting at the given path."""
    # --sort and --columns accept comma-separated lists of column names.
    if sort:
        sort = sort.split(",")
    if columns:
        columns = columns.split(",")
    # Delegate the actual work to the importable implementation.
    commands.list_trials(experiment_path, sort, output, filter_op, columns, limit, desc)
42
+
43
+
44
@cli.command()
@click.argument("project_path", required=True, type=str)
@click.option("--sort", default=None, type=str, help="Select which column to sort on.")
@click.option(
    "--output",
    "-o",
    default=None,
    type=str,
    help="Select file to output information to.",
)
@click.option(
    "--filter",
    "filter_op",
    default=None,
    type=str,
    help="Select filter in the format '<column> <operator> <value>'.",
)
@click.option(
    "--columns", default=None, type=str, help="Select columns to be displayed."
)
@click.option(
    "--limit", default=None, type=int, help="Select number of rows to display."
)
@click.option("--desc", default=False, type=bool, help="Sort ascending vs. descending.")
def list_experiments(project_path, sort, output, filter_op, columns, limit, desc):
    """Lists experiments in the directory subtree."""
    # --sort and --columns accept comma-separated lists of column names.
    if sort:
        sort = sort.split(",")
    if columns:
        columns = columns.split(",")
    # Delegate the actual work to the importable implementation.
    commands.list_experiments(
        project_path, sort, output, filter_op, columns, limit, desc
    )
77
+
78
+
79
@cli.command()
@click.argument("path", required=True, type=str)
@click.option(
    "--filename", default="note.txt", type=str, help="Specify filename for note."
)
def add_note(path, filename):
    """Adds user notes as a text file at the given path."""
    # Thin wrapper: the editor-launching logic lives in commands.add_note.
    commands.add_note(path, filename)
87
+
88
+
89
# Register each subcommand under both a short alias and a verbose name.
cli.add_command(list_trials, name="ls")
cli.add_command(list_trials, name="list-trials")
cli.add_command(list_experiments, name="lsx")
cli.add_command(list_experiments, name="list-experiments")
cli.add_command(add_note, name="add-note")


def main():
    # Entry point used by the console script / direct module execution.
    return cli()


if __name__ == "__main__":
    main()
.venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (190 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/utils.cpython-311.pyc ADDED
Binary file (1.49 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (206 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/common.cpython-311.pyc ADDED
Binary file (16.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_func.cpython-311.pyc ADDED
Binary file (8.83 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_trainable.cpython-311.pyc ADDED
Binary file (9.53 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/common.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import matplotlib.animation as animation
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.parallel
9
+ import torch.utils.data
10
+ import torchvision.datasets as dset
11
+ import torchvision.transforms as transforms
12
+ import torchvision.utils as vutils
13
+ from scipy.stats import entropy
14
+ from torch.autograd import Variable
15
+ from torch.nn import functional as F
16
+
17
+ import ray
18
+
19
# Training parameters
workers = 2  # DataLoader worker processes
batch_size = 64
image_size = 32  # MNIST images are resized to 32x32 for the DCGAN

# Number of channels in the training images. For color images this is 3
nc = 1

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 32

# Size of feature maps in discriminator
ndf = 32

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# iterations of actual training in each Trainable _train
train_iterations_per_step = 5

# Pretrained MNIST classifier checkpoint used by inception_score().
MODEL_PATH = os.path.expanduser("~/.ray/models/mnist_cnn.pt")
43
+
44
+
45
def get_data_loader(data_dir="~/data"):
    """Download MNIST (if needed) and return a shuffled training DataLoader."""
    # Resize to the DCGAN input size and normalize pixels into [-1, 1].
    preprocess = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    dataset = dset.MNIST(root=data_dir, download=True, transform=preprocess)

    # Batch/worker settings come from the module-level training parameters.
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=workers
    )
64
+
65
+
66
# __GANmodel_begin__
# custom weights initialization called on netG and netD
def weights_init(m):
    """DCGAN init: N(0, 0.02) for conv weights; N(1, 0.02) + zero bias for BN."""
    layer_type = m.__class__.__name__
    if "Conv" in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
75
+
76
+
77
# Generator Code
class Generator(nn.Module):
    """DCGAN generator: maps an (nz, 1, 1) latent vector to an (nc, 32, 32) image.

    Layers are kept inside ``self.main`` so checkpoint state-dict keys
    ("main.0.weight", ...) stay compatible with saved checkpoints.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            # Spatial sizes: 1 -> 4 -> 8 -> 16 -> 32.
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # Final layer emits `nc` channels in [-1, 1] via Tanh, matching
            # the Normalize((0.5,), (0.5,)) preprocessing of real images.
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        return self.main(input)
98
+
99
+
100
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an (nc, 32, 32) image to a scalar in (0, 1).

    Layers are kept inside ``self.main`` so checkpoint state-dict keys
    stay compatible with saved checkpoints.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # Spatial sizes: 32 -> 16 -> 8 -> 4 -> 1.
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # Collapse to a single score, squashed to (0, 1) by Sigmoid.
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.main(input)
118
+
119
+
120
+ # __GANmodel_end__
121
+
122
+
123
# __INCEPTION_SCORE_begin__
class Net(nn.Module):
    """
    LeNet for MNist classification, used for inception_score
    """

    def __init__(self):
        super(Net, self).__init__()
        # Two conv blocks followed by a small fully-connected classifier.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # Expects 28x28 single-channel input: 28 -> 24 -> 12 -> 8 -> 4,
        # giving 20 * 4 * 4 = 320 features before the linear layers.
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)  # flatten to (batch, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Log-probabilities over the 10 digit classes.
        return F.log_softmax(x, dim=1)
145
+
146
+
147
def inception_score(imgs, mnist_model_ref, batch_size=32, splits=1):
    """Compute the inception score of ``imgs`` using an MNIST classifier.

    Args:
        imgs: Generated images (a torch-compatible dataset/tensor sequence).
        mnist_model_ref: Ray object ref to the trained MNIST classifier.
        batch_size: Batch size used when scoring images.
        splits: Number of splits over which the score's mean/std is taken.

    Returns:
        Tuple of (mean, std) of the per-split inception scores.
    """
    N = len(imgs)
    dtype = torch.FloatTensor
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
    cm = ray.get(mnist_model_ref)  # Get the mnist model from Ray object store.
    # Generated images are 32x32; the classifier expects 28x28 MNIST input.
    up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype)

    def get_pred(x):
        x = up(x)
        x = cm(x)
        # Explicit dim=1: implicit-dim softmax is deprecated, and for 2-D
        # input it resolves to dim=1, so behavior is unchanged, warning-free.
        return F.softmax(x, dim=1).data.cpu().numpy()

    preds = np.zeros((N, 10))
    for i, batch in enumerate(dataloader, 0):
        batch = batch.type(dtype)
        batchv = Variable(batch)
        batch_size_i = batch.size()[0]
        preds[i * batch_size : i * batch_size + batch_size_i] = get_pred(batchv)

    # Now compute the mean kl-div: per split, IS = exp(mean_x KL(p(y|x) || p(y))).
    split_scores = []
    for k in range(splits):
        part = preds[k * (N // splits) : (k + 1) * (N // splits), :]
        py = np.mean(part, axis=0)  # marginal class distribution p(y)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores.append(entropy(pyx, py))  # scipy entropy(p, q) == KL(p || q)
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)
178
+
179
+
180
+ # __INCEPTION_SCORE_end__
181
+
182
+
183
def train_func(
    netD,
    netG,
    optimG,
    optimD,
    criterion,
    dataloader,
    iteration,
    device,
    mnist_model_ref,
):
    """Run up to ``train_iterations_per_step`` DCGAN training batches.

    Standard DCGAN alternation per batch: update the discriminator on a real
    batch and a fake batch, then update the generator against the updated
    discriminator, and score the fake batch with the MNIST classifier.

    Args:
        netD: Discriminator network.
        netG: Generator network.
        optimG: Optimizer for the generator.
        optimD: Optimizer for the discriminator.
        criterion: GAN loss applied to discriminator outputs.
        dataloader: Training DataLoader; ``data[0]`` holds the image batch.
        iteration: Global iteration counter, used only for periodic logging.
        device: Torch device to run on.
        mnist_model_ref: Ray object ref passed through to inception_score().

    Returns:
        Tuple (generator loss, discriminator loss, inception score) taken
        from the LAST processed batch.
        NOTE(review): raises NameError if the dataloader yields no batches;
        callers are expected to supply a non-empty loader.
    """
    real_label = 1
    fake_label = 0

    for i, data in enumerate(dataloader, 0):
        # Cap the amount of work done per Tune training step.
        if i >= train_iterations_per_step:
            break

        # (1) Update D on an all-real batch.
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()  # mean D score on real images

        # (2) Update D on an all-fake batch. detach() blocks gradient flow
        # into the generator during the discriminator update.
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()  # mean D score on fakes, before G update
        errD = errD_real + errD_fake
        optimD.step()

        # (3) Update G using real labels for the fake batch.
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()  # mean D score on fakes, after D update
        optimG.step()

        # Score this fake batch with the pretrained MNIST classifier.
        is_score, is_std = inception_score(fake, mnist_model_ref)

        # Output training stats
        if iteration % 10 == 0:
            print(
                "[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z))"
                ": %.4f / %.4f \tInception score: %.4f"
                % (
                    iteration,
                    len(dataloader),
                    errD.item(),
                    errG.item(),
                    D_x,
                    D_G_z1,
                    D_G_z2,
                    is_score,
                )
            )

    return errG.item(), errD.item(), is_score
248
+
249
+
250
def plot_images(dataloader):
    """Display a grid of up to 64 training images from the first batch."""
    sample_batch = next(iter(dataloader))
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title("Original Images")
    grid = vutils.make_grid(sample_batch[0][:64], padding=2, normalize=True).cpu()
    # Torch tensors are CHW; matplotlib expects HWC.
    plt.imshow(np.transpose(grid, (1, 2, 0)))

    plt.show()
264
+
265
+
266
def demo_gan(checkpoint_paths):
    """Animate generator outputs across a sequence of checkpoints.

    Loads the generator weights from each checkpoint directory, generates a
    64-image grid from one fixed noise batch per checkpoint, and saves the
    sequence as ./generated.gif.

    Args:
        checkpoint_paths: Directories each containing a "checkpoint.pt" whose
            "netGmodel" entry is a Generator state dict.
    """
    img_list = []
    # Fixed noise so frames from successive checkpoints are comparable.
    fixed_noise = torch.randn(64, nz, 1, 1)
    for path in checkpoint_paths:
        # NOTE(review): no map_location -- loading a GPU-saved checkpoint on
        # a CPU-only machine would fail; confirm checkpoints hold CPU tensors.
        checkpoint_dict = torch.load(os.path.join(path, "checkpoint.pt"))

        loadedG = Generator()
        loadedG.load_state_dict(checkpoint_dict["netGmodel"])
        with torch.no_grad():
            fake = loadedG(fixed_noise).detach().cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

    fig = plt.figure(figsize=(8, 8))
    plt.axis("off")
    # CHW -> HWC for matplotlib; one animation frame per checkpoint.
    ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
    ani = animation.ArtistAnimation(
        fig, ims, interval=1000, repeat_delay=1000, blit=True
    )
    ani.save("./generated.gif", writer="imagemagick", dpi=72)
    plt.show()
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_func.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Example of training DCGAN on MNIST using PBT with Tune's function API.
4
+ """
5
+ import argparse
6
+ import os
7
+ import tempfile
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.parallel
13
+ import torch.optim as optim
14
+ import torch.utils.data
15
+ from filelock import FileLock
16
+
17
+ import ray
18
+ from ray import train, tune
19
+ from ray.train import Checkpoint
20
+ from ray.tune.examples.pbt_dcgan_mnist.common import (
21
+ MODEL_PATH,
22
+ Discriminator,
23
+ Generator,
24
+ Net,
25
+ beta1,
26
+ demo_gan,
27
+ get_data_loader,
28
+ plot_images,
29
+ train_func,
30
+ weights_init,
31
+ )
32
+ from ray.tune.schedulers import PopulationBasedTraining
33
+
34
+
35
+ # __Train_begin__
36
+ def dcgan_train(config):
37
+ use_cuda = config.get("use_gpu") and torch.cuda.is_available()
38
+ device = torch.device("cuda" if use_cuda else "cpu")
39
+ netD = Discriminator().to(device)
40
+ netD.apply(weights_init)
41
+ netG = Generator().to(device)
42
+ netG.apply(weights_init)
43
+ criterion = nn.BCELoss()
44
+ optimizerD = optim.Adam(
45
+ netD.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999)
46
+ )
47
+ optimizerG = optim.Adam(
48
+ netG.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999)
49
+ )
50
+ with FileLock(os.path.expanduser("~/ray_results/.data.lock")):
51
+ dataloader = get_data_loader()
52
+
53
+ step = 1
54
+ checkpoint = train.get_checkpoint()
55
+ if checkpoint:
56
+ with checkpoint.as_directory() as checkpoint_dir:
57
+ checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
58
+ netD.load_state_dict(checkpoint_dict["netDmodel"])
59
+ netG.load_state_dict(checkpoint_dict["netGmodel"])
60
+ optimizerD.load_state_dict(checkpoint_dict["optimD"])
61
+ optimizerG.load_state_dict(checkpoint_dict["optimG"])
62
+ # Note: Make sure to increment the loaded step by 1 to get the
63
+ # current step.
64
+ last_step = checkpoint_dict["step"]
65
+ step = last_step + 1
66
+
67
+ # NOTE: It's important to set the optimizer learning rates
68
+ # again, since we want to explore the parameters passed in by PBT.
69
+ # Without this, we would continue using the exact same
70
+ # configuration as the trial whose checkpoint we are exploiting.
71
+ if "netD_lr" in config:
72
+ for param_group in optimizerD.param_groups:
73
+ param_group["lr"] = config["netD_lr"]
74
+ if "netG_lr" in config:
75
+ for param_group in optimizerG.param_groups:
76
+ param_group["lr"] = config["netG_lr"]
77
+
78
+ while True:
79
+ lossG, lossD, is_score = train_func(
80
+ netD,
81
+ netG,
82
+ optimizerG,
83
+ optimizerD,
84
+ criterion,
85
+ dataloader,
86
+ step,
87
+ device,
88
+ config["mnist_model_ref"],
89
+ )
90
+ metrics = {"lossg": lossG, "lossd": lossD, "is_score": is_score}
91
+
92
+ if step % config["checkpoint_interval"] == 0:
93
+ with tempfile.TemporaryDirectory() as tmpdir:
94
+ torch.save(
95
+ {
96
+ "netDmodel": netD.state_dict(),
97
+ "netGmodel": netG.state_dict(),
98
+ "optimD": optimizerD.state_dict(),
99
+ "optimG": optimizerG.state_dict(),
100
+ "step": step,
101
+ },
102
+ os.path.join(tmpdir, "checkpoint.pt"),
103
+ )
104
+ train.report(metrics, checkpoint=Checkpoint.from_directory(tmpdir))
105
+ else:
106
+ train.report(metrics)
107
+
108
+ step += 1
109
+
110
+
111
+ # __Train_end__
112
+
113
+
114
+ def download_mnist_cnn():
115
+ import urllib.request
116
+
117
+ # Download a pre-trained MNIST model for inception score calculation.
118
+ # This is a tiny model (<100kb).
119
+ if not os.path.exists(MODEL_PATH):
120
+ print("downloading model")
121
+ os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
122
+ urllib.request.urlretrieve(
123
+ "https://github.com/ray-project/ray/raw/master/python/ray/tune/"
124
+ "examples/pbt_dcgan_mnist/mnist_cnn.pt",
125
+ MODEL_PATH,
126
+ )
127
+ return MODEL_PATH
128
+
129
+
130
+ if __name__ == "__main__":
131
+ parser = argparse.ArgumentParser()
132
+ parser.add_argument(
133
+ "--smoke-test", action="store_true", help="Finish quickly for testing"
134
+ )
135
+ parser.add_argument(
136
+ "--data-dir", type=str, default="~/data/", help="Set the path of the dataset."
137
+ )
138
+ args, _ = parser.parse_known_args()
139
+ ray.init()
140
+
141
+ download_mnist_cnn()
142
+
143
+ dataloader = get_data_loader(args.data_dir)
144
+ if not args.smoke_test:
145
+ plot_images(dataloader)
146
+
147
+ # __tune_begin__
148
+
149
+ # load the pretrained mnist classification model for inception_score
150
+ mnist_cnn = Net()
151
+ mnist_cnn.load_state_dict(torch.load(MODEL_PATH))
152
+ mnist_cnn.eval()
153
+ # Put the model in Ray object store.
154
+ mnist_model_ref = ray.put(mnist_cnn)
155
+
156
+ scheduler = PopulationBasedTraining(
157
+ perturbation_interval=5,
158
+ hyperparam_mutations={
159
+ # distribution for resampling
160
+ "netG_lr": lambda: np.random.uniform(1e-2, 1e-5),
161
+ "netD_lr": lambda: np.random.uniform(1e-2, 1e-5),
162
+ },
163
+ )
164
+
165
+ tune_iter = 5 if args.smoke_test else 300
166
+ tuner = tune.Tuner(
167
+ dcgan_train,
168
+ run_config=train.RunConfig(
169
+ name="pbt_dcgan_mnist",
170
+ stop={"training_iteration": tune_iter},
171
+ verbose=1,
172
+ ),
173
+ tune_config=tune.TuneConfig(
174
+ metric="is_score",
175
+ mode="max",
176
+ num_samples=8,
177
+ scheduler=scheduler,
178
+ ),
179
+ param_space={
180
+ "netG_lr": tune.choice([0.0001, 0.0002, 0.0005]),
181
+ "netD_lr": tune.choice([0.0001, 0.0002, 0.0005]),
182
+ "mnist_model_ref": mnist_model_ref,
183
+ },
184
+ )
185
+ results = tuner.fit()
186
+ # __tune_end__
187
+
188
+ # demo of the trained Generators
189
+ if not args.smoke_test:
190
+ checkpoint_paths = [result.checkpoint.to_directory() for result in results]
191
+ demo_gan(checkpoint_paths)
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_trainable.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Example of training DCGAN on MNIST using PBT with Tune's Trainable Class
4
+ API.
5
+ """
6
+ import argparse
7
+ import os
8
+ import random
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.parallel
14
+ import torch.optim as optim
15
+ import torch.utils.data
16
+ from filelock import FileLock
17
+
18
+ import ray
19
+ from ray import train, tune
20
+ from ray.tune.examples.pbt_dcgan_mnist.common import (
21
+ MODEL_PATH,
22
+ Discriminator,
23
+ Generator,
24
+ Net,
25
+ beta1,
26
+ demo_gan,
27
+ get_data_loader,
28
+ plot_images,
29
+ train_func,
30
+ weights_init,
31
+ )
32
+ from ray.tune.schedulers import PopulationBasedTraining
33
+
34
+
35
+ # __Trainable_begin__
36
+ class PytorchTrainable(tune.Trainable):
37
+ def setup(self, config):
38
+ use_cuda = config.get("use_gpu") and torch.cuda.is_available()
39
+ self.device = torch.device("cuda" if use_cuda else "cpu")
40
+ self.netD = Discriminator().to(self.device)
41
+ self.netD.apply(weights_init)
42
+ self.netG = Generator().to(self.device)
43
+ self.netG.apply(weights_init)
44
+ self.criterion = nn.BCELoss()
45
+ self.optimizerD = optim.Adam(
46
+ self.netD.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999)
47
+ )
48
+ self.optimizerG = optim.Adam(
49
+ self.netG.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999)
50
+ )
51
+ with FileLock(os.path.expanduser("~/.data.lock")):
52
+ self.dataloader = get_data_loader(config.get("data_dir", "~/data"))
53
+ self.mnist_model_ref = config["mnist_model_ref"]
54
+
55
+ def step(self):
56
+ lossG, lossD, is_score = train_func(
57
+ self.netD,
58
+ self.netG,
59
+ self.optimizerG,
60
+ self.optimizerD,
61
+ self.criterion,
62
+ self.dataloader,
63
+ self._iteration,
64
+ self.device,
65
+ self.mnist_model_ref,
66
+ )
67
+ return {"lossg": lossG, "lossd": lossD, "is_score": is_score}
68
+
69
+ def save_checkpoint(self, checkpoint_dir):
70
+ path = os.path.join(checkpoint_dir, "checkpoint.pt")
71
+ torch.save(
72
+ {
73
+ "netDmodel": self.netD.state_dict(),
74
+ "netGmodel": self.netG.state_dict(),
75
+ "optimD": self.optimizerD.state_dict(),
76
+ "optimG": self.optimizerG.state_dict(),
77
+ },
78
+ path,
79
+ )
80
+
81
+ return checkpoint_dir
82
+
83
+ def load_checkpoint(self, checkpoint_dir):
84
+ path = os.path.join(checkpoint_dir, "checkpoint.pt")
85
+ checkpoint = torch.load(path)
86
+ self.netD.load_state_dict(checkpoint["netDmodel"])
87
+ self.netG.load_state_dict(checkpoint["netGmodel"])
88
+ self.optimizerD.load_state_dict(checkpoint["optimD"])
89
+ self.optimizerG.load_state_dict(checkpoint["optimG"])
90
+
91
+ def reset_config(self, new_config):
92
+ if "netD_lr" in new_config:
93
+ for param_group in self.optimizerD.param_groups:
94
+ param_group["lr"] = new_config["netD_lr"]
95
+ if "netG_lr" in new_config:
96
+ for param_group in self.optimizerG.param_groups:
97
+ param_group["lr"] = new_config["netG_lr"]
98
+
99
+ self.config = new_config
100
+ return True
101
+
102
+
103
+ # __Trainable_end__
104
+
105
+ if __name__ == "__main__":
106
+ parser = argparse.ArgumentParser()
107
+ parser.add_argument(
108
+ "--smoke-test", action="store_true", help="Finish quickly for testing"
109
+ )
110
+ parser.add_argument(
111
+ "--data-dir", type=str, default="~/data/", help="Set the path of the dataset."
112
+ )
113
+ args, _ = parser.parse_known_args()
114
+ ray.init()
115
+
116
+ import urllib.request
117
+
118
+ # Download a pre-trained MNIST model for inception score calculation.
119
+ # This is a tiny model (<100kb).
120
+ if not os.path.exists(MODEL_PATH):
121
+ print("downloading model")
122
+ os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
123
+ urllib.request.urlretrieve(
124
+ "https://github.com/ray-project/ray/raw/master/python/ray/tune/"
125
+ "examples/pbt_dcgan_mnist/mnist_cnn.pt",
126
+ MODEL_PATH,
127
+ )
128
+
129
+ dataloader = get_data_loader()
130
+ if not args.smoke_test:
131
+ plot_images(dataloader)
132
+
133
+ # load the pretrained mnist classification model for inception_score
134
+ mnist_cnn = Net()
135
+ mnist_cnn.load_state_dict(torch.load(MODEL_PATH))
136
+ mnist_cnn.eval()
137
+ mnist_model_ref = ray.put(mnist_cnn)
138
+
139
+ # __tune_begin__
140
+ scheduler = PopulationBasedTraining(
141
+ time_attr="training_iteration",
142
+ perturbation_interval=5,
143
+ hyperparam_mutations={
144
+ # distribution for resampling
145
+ "netG_lr": lambda: np.random.uniform(1e-2, 1e-5),
146
+ "netD_lr": lambda: np.random.uniform(1e-2, 1e-5),
147
+ },
148
+ )
149
+
150
+ tune_iter = 10 if args.smoke_test else 300
151
+ tuner = tune.Tuner(
152
+ PytorchTrainable,
153
+ run_config=train.RunConfig(
154
+ name="pbt_dcgan_mnist",
155
+ stop={"training_iteration": tune_iter},
156
+ verbose=1,
157
+ checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True),
158
+ ),
159
+ tune_config=tune.TuneConfig(
160
+ metric="is_score",
161
+ mode="max",
162
+ num_samples=8,
163
+ scheduler=scheduler,
164
+ reuse_actors=True,
165
+ ),
166
+ param_space={
167
+ "netG_lr": tune.sample_from(
168
+ lambda spec: random.choice([0.0001, 0.0002, 0.0005])
169
+ ),
170
+ "netD_lr": tune.sample_from(
171
+ lambda spec: random.choice([0.0001, 0.0002, 0.0005])
172
+ ),
173
+ "mnist_model_ref": mnist_model_ref,
174
+ "data_dir": args.data_dir,
175
+ },
176
+ )
177
+ results = tuner.fit()
178
+
179
+ # export_formats=[ExportFormat.MODEL]
180
+ # __tune_end__
181
+
182
+ # demo of the trained Generators
183
+ if not args.smoke_test:
184
+ checkpoint_paths = [result.checkpoint.to_directory() for result in results]
185
+ demo_gan(checkpoint_paths)
.venv/lib/python3.11/site-packages/ray/tune/experimental/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (194 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/output.cpython-311.pyc ADDED
Binary file (45.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/experimental/output.py ADDED
@@ -0,0 +1,1043 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import collections
3
+ import datetime
4
+ import logging
5
+ import math
6
+ import numbers
7
+ import os
8
+ import sys
9
+ import textwrap
10
+ import time
11
+ from dataclasses import dataclass
12
+ from enum import IntEnum
13
+ from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
+ import ray
19
+ from ray._private.dict import flatten_dict, unflattened_lookup
20
+ from ray._private.thirdparty.tabulate.tabulate import (
21
+ DataRow,
22
+ Line,
23
+ TableFormat,
24
+ tabulate,
25
+ )
26
+ from ray.air._internal.usage import AirEntrypoint
27
+ from ray.air.constants import TRAINING_ITERATION
28
+ from ray.train import Checkpoint
29
+ from ray.tune.callback import Callback
30
+ from ray.tune.experiment.trial import Trial
31
+ from ray.tune.result import (
32
+ AUTO_RESULT_KEYS,
33
+ EPISODE_REWARD_MEAN,
34
+ MEAN_ACCURACY,
35
+ MEAN_LOSS,
36
+ TIME_TOTAL_S,
37
+ TIMESTEPS_TOTAL,
38
+ )
39
+ from ray.tune.search.sample import Domain
40
+ from ray.tune.utils.log import Verbosity
41
+
42
+ try:
43
+ import rich
44
+ import rich.layout
45
+ import rich.live
46
+ except ImportError:
47
+ rich = None
48
+
49
+
50
+ logger = logging.getLogger(__name__)
51
+
52
+ # defines the mapping of the key in result and the key to be printed in table.
53
+ # Note this is ordered!
54
+ DEFAULT_COLUMNS = collections.OrderedDict(
55
+ {
56
+ MEAN_ACCURACY: "acc",
57
+ MEAN_LOSS: "loss",
58
+ TRAINING_ITERATION: "iter",
59
+ TIME_TOTAL_S: "total time (s)",
60
+ TIMESTEPS_TOTAL: "ts",
61
+ EPISODE_REWARD_MEAN: "reward",
62
+ }
63
+ )
64
+
65
+ # These keys are blacklisted for printing out training/tuning intermediate/final result!
66
+ BLACKLISTED_KEYS = {
67
+ "config",
68
+ "date",
69
+ "done",
70
+ "hostname",
71
+ "iterations_since_restore",
72
+ "node_ip",
73
+ "pid",
74
+ "time_since_restore",
75
+ "timestamp",
76
+ "trial_id",
77
+ "experiment_tag",
78
+ "should_checkpoint",
79
+ "_report_on", # LIGHTNING_REPORT_STAGE_KEY
80
+ }
81
+
82
+ VALID_SUMMARY_TYPES = {
83
+ int,
84
+ float,
85
+ np.float32,
86
+ np.float64,
87
+ np.int32,
88
+ np.int64,
89
+ type(None),
90
+ }
91
+
92
+ # The order of summarizing trials.
93
+ ORDER = [
94
+ Trial.RUNNING,
95
+ Trial.TERMINATED,
96
+ Trial.PAUSED,
97
+ Trial.PENDING,
98
+ Trial.ERROR,
99
+ ]
100
+
101
+
102
+ class AirVerbosity(IntEnum):
103
+ SILENT = 0
104
+ DEFAULT = 1
105
+ VERBOSE = 2
106
+
107
+ def __repr__(self):
108
+ return str(self.value)
109
+
110
+
111
+ IS_NOTEBOOK = ray.widgets.util.in_notebook()
112
+
113
+
114
+ def get_air_verbosity(
115
+ verbose: Union[int, AirVerbosity, Verbosity]
116
+ ) -> Optional[AirVerbosity]:
117
+ if os.environ.get("RAY_AIR_NEW_OUTPUT", "1") == "0":
118
+ return None
119
+
120
+ if isinstance(verbose, AirVerbosity):
121
+ return verbose
122
+
123
+ verbose_int = verbose if isinstance(verbose, int) else verbose.value
124
+
125
+ # Verbosity 2 and 3 both map to AirVerbosity 2
126
+ verbose_int = min(2, verbose_int)
127
+
128
+ return AirVerbosity(verbose_int)
129
+
130
+
131
+ def _infer_params(config: Dict[str, Any]) -> List[str]:
132
+ params = []
133
+ flat_config = flatten_dict(config)
134
+ for key, val in flat_config.items():
135
+ if isinstance(val, Domain):
136
+ params.append(key)
137
+ # Grid search is a special named field. Because we flattened
138
+ # the whole config, we look it up per string
139
+ if key.endswith("/grid_search"):
140
+ # Truncate `/grid_search`
141
+ params.append(key[:-12])
142
+ return params
143
+
144
+
145
+ def _get_time_str(start_time: float, current_time: float) -> Tuple[str, str]:
146
+ """Get strings representing the current and elapsed time.
147
+
148
+ Args:
149
+ start_time: POSIX timestamp of the start of the tune run
150
+ current_time: POSIX timestamp giving the current time
151
+
152
+ Returns:
153
+ Current time and elapsed time for the current run
154
+ """
155
+ current_time_dt = datetime.datetime.fromtimestamp(current_time)
156
+ start_time_dt = datetime.datetime.fromtimestamp(start_time)
157
+ delta: datetime.timedelta = current_time_dt - start_time_dt
158
+
159
+ rest = delta.total_seconds()
160
+ days = int(rest // (60 * 60 * 24))
161
+
162
+ rest -= days * (60 * 60 * 24)
163
+ hours = int(rest // (60 * 60))
164
+
165
+ rest -= hours * (60 * 60)
166
+ minutes = int(rest // 60)
167
+
168
+ seconds = int(rest - minutes * 60)
169
+
170
+ running_for_str = ""
171
+ if days > 0:
172
+ running_for_str += f"{days:d}d "
173
+
174
+ if hours > 0 or running_for_str:
175
+ running_for_str += f"{hours:d}hr "
176
+
177
+ if minutes > 0 or running_for_str:
178
+ running_for_str += f"{minutes:d}min "
179
+
180
+ running_for_str += f"{seconds:d}s"
181
+
182
+ return f"{current_time_dt:%Y-%m-%d %H:%M:%S}", running_for_str
183
+
184
+
185
+ def _get_trials_by_state(trials: List[Trial]) -> Dict[str, List[Trial]]:
186
+ trials_by_state = collections.defaultdict(list)
187
+ for t in trials:
188
+ trials_by_state[t.status].append(t)
189
+ return trials_by_state
190
+
191
+
192
+ def _get_trials_with_error(trials: List[Trial]) -> List[Trial]:
193
+ return [t for t in trials if t.error_file]
194
+
195
+
196
+ def _infer_user_metrics(trials: List[Trial], limit: int = 4) -> List[str]:
197
+ """Try to infer the metrics to print out.
198
+
199
+ By default, only the first 4 meaningful metrics in `last_result` will be
200
+ inferred as user implied metrics.
201
+ """
202
+ # Using OrderedDict for OrderedSet.
203
+ result = collections.OrderedDict()
204
+ for t in trials:
205
+ if not t.last_result:
206
+ continue
207
+ for metric, value in t.last_result.items():
208
+ if metric not in DEFAULT_COLUMNS:
209
+ if metric not in AUTO_RESULT_KEYS:
210
+ if type(value) in VALID_SUMMARY_TYPES:
211
+ result[metric] = "" # not important
212
+
213
+ if len(result) >= limit:
214
+ return list(result.keys())
215
+ return list(result.keys())
216
+
217
+
218
+ def _current_best_trial(
219
+ trials: List[Trial], metric: Optional[str], mode: Optional[str]
220
+ ) -> Tuple[Optional[Trial], Optional[str]]:
221
+ """
222
+ Returns the best trial and the metric key. If anything is empty or None,
223
+ returns a trivial result of None, None.
224
+
225
+ Args:
226
+ trials: List of trials.
227
+ metric: Metric that trials are being ranked.
228
+ mode: One of "min" or "max".
229
+
230
+ Returns:
231
+ Best trial and the metric key.
232
+ """
233
+ if not trials or not metric or not mode:
234
+ return None, None
235
+
236
+ metric_op = 1.0 if mode == "max" else -1.0
237
+ best_metric = float("-inf")
238
+ best_trial = None
239
+ for t in trials:
240
+ if not t.last_result:
241
+ continue
242
+ metric_value = unflattened_lookup(metric, t.last_result, default=None)
243
+ if pd.isnull(metric_value):
244
+ continue
245
+ if not best_trial or metric_value * metric_op > best_metric:
246
+ best_metric = metric_value * metric_op
247
+ best_trial = t
248
+ return best_trial, metric
249
+
250
+
251
+ @dataclass
252
+ class _PerStatusTrialTableData:
253
+ trial_infos: List[List[str]]
254
+ more_info: str
255
+
256
+
257
+ @dataclass
258
+ class _TrialTableData:
259
+ header: List[str]
260
+ data: List[_PerStatusTrialTableData]
261
+
262
+
263
+ def _max_len(value: Any, max_len: int = 20, wrap: bool = False) -> Any:
264
+ """Abbreviate a string representation of an object to `max_len` characters.
265
+
266
+ For numbers, booleans and None, the original value will be returned for
267
+ correct rendering in the table formatting tool.
268
+
269
+ Args:
270
+ value: Object to be represented as a string.
271
+ max_len: Maximum return string length.
272
+ """
273
+ if value is None or isinstance(value, (int, float, numbers.Number, bool)):
274
+ return value
275
+
276
+ string = str(value)
277
+ if len(string) <= max_len:
278
+ return string
279
+
280
+ if wrap:
281
+ # Maximum two rows.
282
+ # Todo: Make this configurable in the refactor
283
+ if len(value) > max_len * 2:
284
+ value = "..." + string[(3 - (max_len * 2)) :]
285
+
286
+ wrapped = textwrap.wrap(value, width=max_len)
287
+ return "\n".join(wrapped)
288
+
289
+ result = "..." + string[(3 - max_len) :]
290
+ return result
291
+
292
+
293
+ def _get_trial_info(
294
+ trial: Trial, param_keys: List[str], metric_keys: List[str]
295
+ ) -> List[str]:
296
+ """Returns the following information about a trial:
297
+
298
+ name | status | metrics...
299
+
300
+ Args:
301
+ trial: Trial to get information for.
302
+ param_keys: Names of parameters to include.
303
+ metric_keys: Names of metrics to include.
304
+ """
305
+ result = trial.last_result
306
+ trial_info = [str(trial), trial.status]
307
+
308
+ # params
309
+ trial_info.extend(
310
+ [
311
+ _max_len(
312
+ unflattened_lookup(param, trial.config, default=None),
313
+ )
314
+ for param in param_keys
315
+ ]
316
+ )
317
+ # metrics
318
+ trial_info.extend(
319
+ [
320
+ _max_len(
321
+ unflattened_lookup(metric, result, default=None),
322
+ )
323
+ for metric in metric_keys
324
+ ]
325
+ )
326
+ return trial_info
327
+
328
+
329
+ def _get_trial_table_data_per_status(
330
+ status: str,
331
+ trials: List[Trial],
332
+ param_keys: List[str],
333
+ metric_keys: List[str],
334
+ force_max_rows: bool = False,
335
+ ) -> Optional[_PerStatusTrialTableData]:
336
+ """Gather all information of trials pertained to one `status`.
337
+
338
+ Args:
339
+ status: The trial status of interest.
340
+ trials: all the trials of that status.
341
+ param_keys: *Ordered* list of parameters to be displayed in the table.
342
+ metric_keys: *Ordered* list of metrics to be displayed in the table.
343
+ Including both default and user defined.
344
+ force_max_rows: Whether or not to enforce a max row number for this status.
345
+ If True, only a max of `5` rows will be shown.
346
+
347
+ Returns:
348
+ All information of trials pertained to the `status`.
349
+ """
350
+ # TODO: configure it.
351
+ max_row = 5 if force_max_rows else math.inf
352
+ if not trials:
353
+ return None
354
+
355
+ trial_infos = list()
356
+ more_info = None
357
+ for t in trials:
358
+ if len(trial_infos) >= max_row:
359
+ remaining = len(trials) - max_row
360
+ more_info = f"{remaining} more {status}"
361
+ break
362
+ trial_infos.append(_get_trial_info(t, param_keys, metric_keys))
363
+ return _PerStatusTrialTableData(trial_infos, more_info)
364
+
365
+
366
+ def _get_trial_table_data(
367
+ trials: List[Trial],
368
+ param_keys: List[str],
369
+ metric_keys: List[str],
370
+ all_rows: bool = False,
371
+ wrap_headers: bool = False,
372
+ ) -> _TrialTableData:
373
+ """Generate a table showing the current progress of tuning trials.
374
+
375
+ Args:
376
+ trials: List of trials for which progress is to be shown.
377
+ param_keys: Ordered list of parameters to be displayed in the table.
378
+ metric_keys: Ordered list of metrics to be displayed in the table.
379
+ Including both default and user defined.
380
+ Will only be shown if at least one trial is having the key.
381
+ all_rows: Force to show all rows.
382
+ wrap_headers: If True, header columns can be wrapped with ``\n``.
383
+
384
+ Returns:
385
+ Trial table data, including header and trial table per each status.
386
+ """
387
+ # TODO: configure
388
+ max_trial_num_to_show = 20
389
+ max_column_length = 20
390
+ trials_by_state = _get_trials_by_state(trials)
391
+
392
+ # get the right metric to show.
393
+ metric_keys = [
394
+ k
395
+ for k in metric_keys
396
+ if any(
397
+ unflattened_lookup(k, t.last_result, default=None) is not None
398
+ for t in trials
399
+ )
400
+ ]
401
+
402
+ # get header from metric keys
403
+ formatted_metric_columns = [
404
+ _max_len(k, max_len=max_column_length, wrap=wrap_headers) for k in metric_keys
405
+ ]
406
+
407
+ formatted_param_columns = [
408
+ _max_len(k, max_len=max_column_length, wrap=wrap_headers) for k in param_keys
409
+ ]
410
+
411
+ metric_header = [
412
+ DEFAULT_COLUMNS[metric] if metric in DEFAULT_COLUMNS else formatted
413
+ for metric, formatted in zip(metric_keys, formatted_metric_columns)
414
+ ]
415
+
416
+ param_header = formatted_param_columns
417
+
418
+ # Map to the abbreviated version if necessary.
419
+ header = ["Trial name", "status"] + param_header + metric_header
420
+
421
+ trial_data = list()
422
+ for t_status in ORDER:
423
+ trial_data_per_status = _get_trial_table_data_per_status(
424
+ t_status,
425
+ trials_by_state[t_status],
426
+ param_keys=param_keys,
427
+ metric_keys=metric_keys,
428
+ force_max_rows=not all_rows and len(trials) > max_trial_num_to_show,
429
+ )
430
+ if trial_data_per_status:
431
+ trial_data.append(trial_data_per_status)
432
+ return _TrialTableData(header, trial_data)
433
+
434
+
435
+ def _best_trial_str(
436
+ trial: Trial,
437
+ metric: str,
438
+ ):
439
+ """Returns a readable message stating the current best trial."""
440
+ # returns something like
441
+ # Current best trial: 18ae7_00005 with loss=0.5918508041056858 and params={'train_loop_config': {'lr': 0.059253447253394785}}. # noqa
442
+ val = unflattened_lookup(metric, trial.last_result, default=None)
443
+ config = trial.last_result.get("config", {})
444
+ parameter_columns = list(config.keys())
445
+ params = {p: unflattened_lookup(p, config) for p in parameter_columns}
446
+ return (
447
+ f"Current best trial: {trial.trial_id} with {metric}={val} and "
448
+ f"params={params}"
449
+ )
450
+
451
+
452
+ def _render_table_item(
453
+ key: str, item: Any, prefix: str = ""
454
+ ) -> Iterable[Tuple[str, str]]:
455
+ key = prefix + key
456
+
457
+ if isinstance(item, argparse.Namespace):
458
+ item = item.__dict__
459
+
460
+ if isinstance(item, float):
461
+ # tabulate does not work well with mixed-type columns, so we format
462
+ # numbers ourselves.
463
+ yield key, f"{item:.5f}".rstrip("0")
464
+ elif isinstance(item, dict):
465
+ flattened = flatten_dict(item)
466
+ for k, v in sorted(flattened.items()):
467
+ yield key + "/" + str(k), _max_len(v)
468
+ else:
469
+ yield key, _max_len(item, 20)
470
+
471
+
472
+ def _get_dict_as_table_data(
473
+ data: Dict,
474
+ include: Optional[Collection] = None,
475
+ exclude: Optional[Collection] = None,
476
+ upper_keys: Optional[Collection] = None,
477
+ ):
478
+ """Get ``data`` dict as table rows.
479
+
480
+ If specified, excluded keys are removed. Excluded keys can either be
481
+ fully specified (e.g. ``foo/bar/baz``) or specify a top-level dictionary
482
+ (e.g. ``foo``), but no intermediate levels (e.g. ``foo/bar``). If this is
483
+ needed, we can revisit the logic at a later point.
484
+
485
+ The same is true for included keys. If a top-level key is included (e.g. ``foo``)
486
+ then all sub keys will be included, too, except if they are excluded.
487
+
488
+ If keys are both excluded and included, exclusion takes precedence. Thus, if
489
+ ``foo`` is excluded but ``foo/bar`` is included, it won't show up in the output.
490
+ """
491
+ include = include or set()
492
+ exclude = exclude or set()
493
+ upper_keys = upper_keys or set()
494
+
495
+ upper = []
496
+ lower = []
497
+
498
+ for key, value in sorted(data.items()):
499
+ # Exclude top-level keys
500
+ if key in exclude:
501
+ continue
502
+
503
+ for k, v in _render_table_item(str(key), value):
504
+ # k is now the full subkey, e.g. config/nested/key
505
+
506
+ # We can exclude the full key
507
+ if k in exclude:
508
+ continue
509
+
510
+ # If we specify includes, top-level includes should take precedence
511
+ # (e.g. if `config` is in include, include config always).
512
+ if include and key not in include and k not in include:
513
+ continue
514
+
515
+ if key in upper_keys:
516
+ upper.append([k, v])
517
+ else:
518
+ lower.append([k, v])
519
+
520
+ if not upper:
521
+ return lower
522
+ elif not lower:
523
+ return upper
524
+ else:
525
+ return upper + lower
526
+
527
+
528
+ if sys.stdout and sys.stdout.encoding and sys.stdout.encoding.startswith("utf"):
529
+ # Copied/adjusted from tabulate
530
+ AIR_TABULATE_TABLEFMT = TableFormat(
531
+ lineabove=Line("╭", "─", "─", "╮"),
532
+ linebelowheader=Line("├", "─", "─", "┤"),
533
+ linebetweenrows=None,
534
+ linebelow=Line("╰", "─", "─", "╯"),
535
+ headerrow=DataRow("│", " ", "│"),
536
+ datarow=DataRow("│", " ", "│"),
537
+ padding=1,
538
+ with_header_hide=None,
539
+ )
540
+ else:
541
+ # For non-utf output, use ascii-compatible characters.
542
+ # This prevents errors e.g. when legacy windows encoding is used.
543
+ AIR_TABULATE_TABLEFMT = TableFormat(
544
+ lineabove=Line("+", "-", "-", "+"),
545
+ linebelowheader=Line("+", "-", "-", "+"),
546
+ linebetweenrows=None,
547
+ linebelow=Line("+", "-", "-", "+"),
548
+ headerrow=DataRow("|", " ", "|"),
549
+ datarow=DataRow("|", " ", "|"),
550
+ padding=1,
551
+ with_header_hide=None,
552
+ )
553
+
554
+
555
+ def _print_dict_as_table(
556
+ data: Dict,
557
+ header: Optional[str] = None,
558
+ include: Optional[Collection[str]] = None,
559
+ exclude: Optional[Collection[str]] = None,
560
+ division: Optional[Collection[str]] = None,
561
+ ):
562
+ table_data = _get_dict_as_table_data(
563
+ data=data, include=include, exclude=exclude, upper_keys=division
564
+ )
565
+
566
+ headers = [header, ""] if header else []
567
+
568
+ if not table_data:
569
+ return
570
+
571
+ print(
572
+ tabulate(
573
+ table_data,
574
+ headers=headers,
575
+ colalign=("left", "right"),
576
+ tablefmt=AIR_TABULATE_TABLEFMT,
577
+ )
578
+ )
579
+
580
+
581
+ class ProgressReporter(Callback):
582
+ """Periodically prints out status update."""
583
+
584
+ # TODO: Make this configurable
585
+ _heartbeat_freq = 30 # every 30 sec
586
+ # to be updated by subclasses.
587
+ _heartbeat_threshold = None
588
+ _start_end_verbosity = None
589
+ _intermediate_result_verbosity = None
590
+ _addressing_tmpl = None
591
+
592
+ def __init__(
593
+ self,
594
+ verbosity: AirVerbosity,
595
+ progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
596
+ ):
597
+ """
598
+
599
+ Args:
600
+ verbosity: AirVerbosity level.
601
+ """
602
+ self._verbosity = verbosity
603
+ self._start_time = time.time()
604
+ self._last_heartbeat_time = float("-inf")
605
+ self._start_time = time.time()
606
+ self._progress_metrics = progress_metrics
607
+ self._trial_last_printed_results = {}
608
+
609
+ self._in_block = None
610
+
611
+ @property
612
+ def verbosity(self) -> AirVerbosity:
613
+ return self._verbosity
614
+
615
+ def setup(
616
+ self,
617
+ start_time: Optional[float] = None,
618
+ **kwargs,
619
+ ):
620
+ self._start_time = start_time
621
+
622
+ def _start_block(self, indicator: Any):
623
+ if self._in_block != indicator:
624
+ self._end_block()
625
+ self._in_block = indicator
626
+
627
+ def _end_block(self):
628
+ if self._in_block:
629
+ print("")
630
+ self._in_block = None
631
+
632
+ def on_experiment_end(self, trials: List["Trial"], **info):
633
+ self._end_block()
634
+
635
+ def experiment_started(
636
+ self,
637
+ experiment_name: str,
638
+ experiment_path: str,
639
+ searcher_str: str,
640
+ scheduler_str: str,
641
+ total_num_samples: int,
642
+ tensorboard_path: Optional[str] = None,
643
+ **kwargs,
644
+ ):
645
+ self._start_block("exp_start")
646
+ print(f"\nView detailed results here: {experiment_path}")
647
+
648
+ if tensorboard_path:
649
+ print(
650
+ f"To visualize your results with TensorBoard, run: "
651
+ f"`tensorboard --logdir {tensorboard_path}`"
652
+ )
653
+
654
+ @property
655
+ def _time_heartbeat_str(self):
656
+ current_time_str, running_time_str = _get_time_str(
657
+ self._start_time, time.time()
658
+ )
659
+ return (
660
+ f"Current time: {current_time_str}. Total running time: " + running_time_str
661
+ )
662
+
663
+ def print_heartbeat(self, trials, *args, force: bool = False):
664
+ if self._verbosity < self._heartbeat_threshold:
665
+ return
666
+ if force or time.time() - self._last_heartbeat_time >= self._heartbeat_freq:
667
+ self._print_heartbeat(trials, *args, force=force)
668
+ self._last_heartbeat_time = time.time()
669
+
670
+ def _print_heartbeat(self, trials, *args, force: bool = False):
671
+ raise NotImplementedError
672
+
673
+ def _print_result(self, trial, result: Optional[Dict] = None, force: bool = False):
674
+ """Only print result if a different result has been reported, or force=True"""
675
+ result = result or trial.last_result
676
+
677
+ last_result_iter = self._trial_last_printed_results.get(trial.trial_id, -1)
678
+ this_iter = result.get(TRAINING_ITERATION, 0)
679
+
680
+ if this_iter != last_result_iter or force:
681
+ _print_dict_as_table(
682
+ result,
683
+ header=f"{self._addressing_tmpl.format(trial)} result",
684
+ include=self._progress_metrics,
685
+ exclude=BLACKLISTED_KEYS,
686
+ division=AUTO_RESULT_KEYS,
687
+ )
688
+ self._trial_last_printed_results[trial.trial_id] = this_iter
689
+
690
+ def _print_config(self, trial):
691
+ _print_dict_as_table(
692
+ trial.config, header=f"{self._addressing_tmpl.format(trial)} config"
693
+ )
694
+
695
+ def on_trial_result(
696
+ self,
697
+ iteration: int,
698
+ trials: List[Trial],
699
+ trial: Trial,
700
+ result: Dict,
701
+ **info,
702
+ ):
703
+ if self.verbosity < self._intermediate_result_verbosity:
704
+ return
705
+ self._start_block(f"trial_{trial}_result_{result[TRAINING_ITERATION]}")
706
+ curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
707
+ print(
708
+ f"{self._addressing_tmpl.format(trial)} "
709
+ f"finished iteration {result[TRAINING_ITERATION]} "
710
+ f"at {curr_time_str}. Total running time: " + running_time_str
711
+ )
712
+ self._print_result(trial, result)
713
+
714
+ def on_trial_complete(
715
+ self, iteration: int, trials: List[Trial], trial: Trial, **info
716
+ ):
717
+ if self.verbosity < self._start_end_verbosity:
718
+ return
719
+ curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
720
+ finished_iter = 0
721
+ if trial.last_result and TRAINING_ITERATION in trial.last_result:
722
+ finished_iter = trial.last_result[TRAINING_ITERATION]
723
+
724
+ self._start_block(f"trial_{trial}_complete")
725
+ print(
726
+ f"{self._addressing_tmpl.format(trial)} "
727
+ f"completed after {finished_iter} iterations "
728
+ f"at {curr_time_str}. Total running time: " + running_time_str
729
+ )
730
+ self._print_result(trial)
731
+
732
+ def on_trial_error(
733
+ self, iteration: int, trials: List["Trial"], trial: "Trial", **info
734
+ ):
735
+ curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
736
+ finished_iter = 0
737
+ if trial.last_result and TRAINING_ITERATION in trial.last_result:
738
+ finished_iter = trial.last_result[TRAINING_ITERATION]
739
+
740
+ self._start_block(f"trial_{trial}_error")
741
+ print(
742
+ f"{self._addressing_tmpl.format(trial)} "
743
+ f"errored after {finished_iter} iterations "
744
+ f"at {curr_time_str}. Total running time: {running_time_str}\n"
745
+ f"Error file: {trial.error_file}"
746
+ )
747
+ self._print_result(trial)
748
+
749
+ def on_trial_recover(
750
+ self, iteration: int, trials: List["Trial"], trial: "Trial", **info
751
+ ):
752
+ self.on_trial_error(iteration=iteration, trials=trials, trial=trial, **info)
753
+
754
+ def on_checkpoint(
755
+ self,
756
+ iteration: int,
757
+ trials: List[Trial],
758
+ trial: Trial,
759
+ checkpoint: Checkpoint,
760
+ **info,
761
+ ):
762
+ if self._verbosity < self._intermediate_result_verbosity:
763
+ return
764
+ # don't think this is supposed to happen but just to be safe.
765
+ saved_iter = "?"
766
+ if trial.last_result and TRAINING_ITERATION in trial.last_result:
767
+ saved_iter = trial.last_result[TRAINING_ITERATION]
768
+
769
+ self._start_block(f"trial_{trial}_result_{saved_iter}")
770
+
771
+ loc = f"({checkpoint.filesystem.type_name}){checkpoint.path}"
772
+
773
+ print(
774
+ f"{self._addressing_tmpl.format(trial)} "
775
+ f"saved a checkpoint for iteration {saved_iter} "
776
+ f"at: {loc}"
777
+ )
778
+
779
+ def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial, **info):
780
+ if self.verbosity < self._start_end_verbosity:
781
+ return
782
+ has_config = bool(trial.config)
783
+
784
+ self._start_block(f"trial_{trial}_start")
785
+ if has_config:
786
+ print(
787
+ f"{self._addressing_tmpl.format(trial)} " f"started with configuration:"
788
+ )
789
+ self._print_config(trial)
790
+ else:
791
+ print(
792
+ f"{self._addressing_tmpl.format(trial)} "
793
+ f"started without custom configuration."
794
+ )
795
+
796
+
797
+ def _detect_reporter(
798
+ verbosity: AirVerbosity,
799
+ num_samples: int,
800
+ entrypoint: Optional[AirEntrypoint] = None,
801
+ metric: Optional[str] = None,
802
+ mode: Optional[str] = None,
803
+ config: Optional[Dict] = None,
804
+ progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
805
+ ):
806
+ if entrypoint in {
807
+ AirEntrypoint.TUNE_RUN,
808
+ AirEntrypoint.TUNE_RUN_EXPERIMENTS,
809
+ AirEntrypoint.TUNER,
810
+ }:
811
+ reporter = TuneTerminalReporter(
812
+ verbosity,
813
+ num_samples=num_samples,
814
+ metric=metric,
815
+ mode=mode,
816
+ config=config,
817
+ progress_metrics=progress_metrics,
818
+ )
819
+ else:
820
+ reporter = TrainReporter(verbosity, progress_metrics=progress_metrics)
821
+ return reporter
822
+
823
+
824
+ class TuneReporterBase(ProgressReporter):
825
+ _heartbeat_threshold = AirVerbosity.DEFAULT
826
+ _wrap_headers = False
827
+ _intermediate_result_verbosity = AirVerbosity.VERBOSE
828
+ _start_end_verbosity = AirVerbosity.DEFAULT
829
+ _addressing_tmpl = "Trial {}"
830
+
831
+ def __init__(
832
+ self,
833
+ verbosity: AirVerbosity,
834
+ num_samples: int = 0,
835
+ metric: Optional[str] = None,
836
+ mode: Optional[str] = None,
837
+ config: Optional[Dict] = None,
838
+ progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
839
+ ):
840
+ self._num_samples = num_samples
841
+ self._metric = metric
842
+ self._mode = mode
843
+ # will be populated when first result comes in.
844
+ self._inferred_metric = None
845
+ self._inferred_params = _infer_params(config or {})
846
+ super(TuneReporterBase, self).__init__(
847
+ verbosity=verbosity, progress_metrics=progress_metrics
848
+ )
849
+
850
+ def setup(
851
+ self,
852
+ start_time: Optional[float] = None,
853
+ total_samples: Optional[int] = None,
854
+ **kwargs,
855
+ ):
856
+ super().setup(start_time=start_time)
857
+ self._num_samples = total_samples
858
+
859
+ def _get_overall_trial_progress_str(self, trials):
860
+ result = " | ".join(
861
+ [
862
+ f"{len(trials)} {status}"
863
+ for status, trials in _get_trials_by_state(trials).items()
864
+ ]
865
+ )
866
+ return f"Trial status: {result}"
867
+
868
+ # TODO: Return a more structured type to share code with Jupyter flow.
869
+ def _get_heartbeat(
870
+ self, trials, *sys_args, force_full_output: bool = False
871
+ ) -> Tuple[List[str], _TrialTableData]:
872
+ result = list()
873
+ # Trial status: 1 RUNNING | 7 PENDING
874
+ result.append(self._get_overall_trial_progress_str(trials))
875
+ # Current time: 2023-02-24 12:35:39 (running for 00:00:37.40)
876
+ result.append(self._time_heartbeat_str)
877
+ # Logical resource usage: 8.0/64 CPUs, 0/0 GPUs
878
+ result.extend(sys_args)
879
+ # Current best trial: TRIAL NAME, metrics: {...}, parameters: {...}
880
+ current_best_trial, metric = _current_best_trial(
881
+ trials, self._metric, self._mode
882
+ )
883
+ if current_best_trial:
884
+ result.append(_best_trial_str(current_best_trial, metric))
885
+ # Now populating the trial table data.
886
+ if not self._inferred_metric:
887
+ # try inferring again.
888
+ self._inferred_metric = _infer_user_metrics(trials)
889
+
890
+ all_metrics = list(DEFAULT_COLUMNS.keys()) + self._inferred_metric
891
+
892
+ trial_table_data = _get_trial_table_data(
893
+ trials,
894
+ param_keys=self._inferred_params,
895
+ metric_keys=all_metrics,
896
+ all_rows=force_full_output,
897
+ wrap_headers=self._wrap_headers,
898
+ )
899
+ return result, trial_table_data
900
+
901
+ def _print_heartbeat(self, trials, *sys_args, force: bool = False):
902
+ raise NotImplementedError
903
+
904
+
905
+ class TuneTerminalReporter(TuneReporterBase):
906
+ def experiment_started(
907
+ self,
908
+ experiment_name: str,
909
+ experiment_path: str,
910
+ searcher_str: str,
911
+ scheduler_str: str,
912
+ total_num_samples: int,
913
+ tensorboard_path: Optional[str] = None,
914
+ **kwargs,
915
+ ):
916
+ if total_num_samples > sys.maxsize:
917
+ total_num_samples_str = "infinite"
918
+ else:
919
+ total_num_samples_str = str(total_num_samples)
920
+
921
+ print(
922
+ tabulate(
923
+ [
924
+ ["Search algorithm", searcher_str],
925
+ ["Scheduler", scheduler_str],
926
+ ["Number of trials", total_num_samples_str],
927
+ ],
928
+ headers=["Configuration for experiment", experiment_name],
929
+ tablefmt=AIR_TABULATE_TABLEFMT,
930
+ )
931
+ )
932
+ super().experiment_started(
933
+ experiment_name=experiment_name,
934
+ experiment_path=experiment_path,
935
+ searcher_str=searcher_str,
936
+ scheduler_str=scheduler_str,
937
+ total_num_samples=total_num_samples,
938
+ tensorboard_path=tensorboard_path,
939
+ **kwargs,
940
+ )
941
+
942
+ def _print_heartbeat(self, trials, *sys_args, force: bool = False):
943
+ if self._verbosity < self._heartbeat_threshold and not force:
944
+ return
945
+ heartbeat_strs, table_data = self._get_heartbeat(
946
+ trials, *sys_args, force_full_output=force
947
+ )
948
+
949
+ self._start_block("heartbeat")
950
+ for s in heartbeat_strs:
951
+ print(s)
952
+ # now print the table using Tabulate
953
+ more_infos = []
954
+ all_data = []
955
+ fail_header = table_data.header
956
+ for sub_table in table_data.data:
957
+ all_data.extend(sub_table.trial_infos)
958
+ if sub_table.more_info:
959
+ more_infos.append(sub_table.more_info)
960
+
961
+ print(
962
+ tabulate(
963
+ all_data,
964
+ headers=fail_header,
965
+ tablefmt=AIR_TABULATE_TABLEFMT,
966
+ showindex=False,
967
+ )
968
+ )
969
+ if more_infos:
970
+ print(", ".join(more_infos))
971
+
972
+ if not force:
973
+ # Only print error table at end of training
974
+ return
975
+
976
+ trials_with_error = _get_trials_with_error(trials)
977
+ if not trials_with_error:
978
+ return
979
+
980
+ self._start_block("status_errored")
981
+ print(f"Number of errored trials: {len(trials_with_error)}")
982
+ fail_header = ["Trial name", "# failures", "error file"]
983
+ fail_table_data = [
984
+ [
985
+ str(trial),
986
+ str(trial.run_metadata.num_failures)
987
+ + ("" if trial.status == Trial.ERROR else "*"),
988
+ trial.error_file,
989
+ ]
990
+ for trial in trials_with_error
991
+ ]
992
+ print(
993
+ tabulate(
994
+ fail_table_data,
995
+ headers=fail_header,
996
+ tablefmt=AIR_TABULATE_TABLEFMT,
997
+ showindex=False,
998
+ colalign=("left", "right", "left"),
999
+ )
1000
+ )
1001
+ if any(trial.status == Trial.TERMINATED for trial in trials_with_error):
1002
+ print("* The trial terminated successfully after retrying.")
1003
+
1004
+
1005
+ class TrainReporter(ProgressReporter):
1006
+ # the minimal verbosity threshold at which heartbeat starts getting printed.
1007
+ _heartbeat_threshold = AirVerbosity.VERBOSE
1008
+ _intermediate_result_verbosity = AirVerbosity.DEFAULT
1009
+ _start_end_verbosity = AirVerbosity.DEFAULT
1010
+ _addressing_tmpl = "Training"
1011
+
1012
+ def _get_heartbeat(self, trials: List[Trial], force_full_output: bool = False):
1013
+ # Training on iteration 1. Current time: 2023-03-22 15:29:25 (running for 00:00:03.24) # noqa
1014
+ if len(trials) == 0:
1015
+ return
1016
+ trial = trials[0]
1017
+ if trial.status != Trial.RUNNING:
1018
+ return " ".join(
1019
+ [f"Training is in {trial.status} status.", self._time_heartbeat_str]
1020
+ )
1021
+ if not trial.last_result or TRAINING_ITERATION not in trial.last_result:
1022
+ iter_num = 1
1023
+ else:
1024
+ iter_num = trial.last_result[TRAINING_ITERATION] + 1
1025
+ return " ".join(
1026
+ [f"Training on iteration {iter_num}.", self._time_heartbeat_str]
1027
+ )
1028
+
1029
+ def _print_heartbeat(self, trials, *args, force: bool = False):
1030
+ print(self._get_heartbeat(trials, force_full_output=force))
1031
+
1032
+ def on_trial_result(
1033
+ self,
1034
+ iteration: int,
1035
+ trials: List[Trial],
1036
+ trial: Trial,
1037
+ result: Dict,
1038
+ **info,
1039
+ ):
1040
+ self._last_heartbeat_time = time.time()
1041
+ super().on_trial_result(
1042
+ iteration=iteration, trials=trials, trial=trial, result=result, **info
1043
+ )
.venv/lib/python3.11/site-packages/ray/tune/logger/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.tune.logger.csv import CSVLogger, CSVLoggerCallback
2
+ from ray.tune.logger.json import JsonLogger, JsonLoggerCallback
3
+ from ray.tune.logger.logger import (
4
+ LegacyLoggerCallback,
5
+ Logger,
6
+ LoggerCallback,
7
+ pretty_print,
8
+ )
9
+ from ray.tune.logger.noop import NoopLogger
10
+ from ray.tune.logger.tensorboardx import TBXLogger, TBXLoggerCallback
11
+
12
+ DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TBXLogger)
13
+
14
+ # isort: off
15
+ from ray.tune.logger.unified import UnifiedLogger # noqa: E402
16
+
17
+ # isort: on
18
+
19
+ __all__ = [
20
+ "Logger",
21
+ "LoggerCallback",
22
+ "LegacyLoggerCallback",
23
+ "pretty_print",
24
+ "CSVLogger",
25
+ "CSVLoggerCallback",
26
+ "JsonLogger",
27
+ "JsonLoggerCallback",
28
+ "NoopLogger",
29
+ "TBXLogger",
30
+ "TBXLoggerCallback",
31
+ "UnifiedLogger",
32
+ ]
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (982 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/aim.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/comet.cpython-311.pyc ADDED
Binary file (327 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/csv.cpython-311.pyc ADDED
Binary file (7.66 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/json.cpython-311.pyc ADDED
Binary file (8.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/logger.cpython-311.pyc ADDED
Binary file (13.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/mlflow.cpython-311.pyc ADDED
Binary file (331 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/noop.cpython-311.pyc ADDED
Binary file (875 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/tensorboardx.cpython-311.pyc ADDED
Binary file (17.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/unified.cpython-311.pyc ADDED
Binary file (4.34 kB). View file
 
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/wandb.cpython-311.pyc ADDED
Binary file (327 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/tune/logger/aim.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+
6
+ from ray.air.constants import TRAINING_ITERATION
7
+ from ray.tune.logger.logger import LoggerCallback
8
+ from ray.tune.result import TIME_TOTAL_S, TIMESTEPS_TOTAL
9
+ from ray.tune.utils import flatten_dict
10
+ from ray.util.annotations import PublicAPI
11
+
12
+ if TYPE_CHECKING:
13
+ from ray.tune.experiment.trial import Trial
14
+
15
+ try:
16
+ from aim.sdk import Repo, Run
17
+ except ImportError:
18
+ Repo, Run = None, None
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
23
+
24
+
25
+ @PublicAPI
26
+ class AimLoggerCallback(LoggerCallback):
27
+ """Aim Logger: logs metrics in Aim format.
28
+
29
+ Aim is an open-source, self-hosted ML experiment tracking tool.
30
+ It's good at tracking lots (thousands) of training runs, and it allows you to
31
+ compare them with a performant and well-designed UI.
32
+
33
+ Source: https://github.com/aimhubio/aim
34
+
35
+ Args:
36
+ repo: Aim repository directory or a `Repo` object that the Run object will
37
+ log results to. If not provided, a default repo will be set up in the
38
+ experiment directory (one level above trial directories).
39
+ experiment: Sets the `experiment` property of each Run object, which is the
40
+ experiment name associated with it. Can be used later to query
41
+ runs/sequences.
42
+ If not provided, the default will be the Tune experiment name set
43
+ by `RunConfig(name=...)`.
44
+ metrics: List of metric names (out of the metrics reported by Tune) to
45
+ track in Aim. If no metric are specified, log everything that
46
+ is reported.
47
+ aim_run_kwargs: Additional arguments that will be passed when creating the
48
+ individual `Run` objects for each trial. For the full list of arguments,
49
+ please see the Aim documentation:
50
+ https://aimstack.readthedocs.io/en/latest/refs/sdk.html
51
+ """
52
+
53
+ VALID_HPARAMS = (str, bool, int, float, list, type(None))
54
+ VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64)
55
+
56
+ def __init__(
57
+ self,
58
+ repo: Optional[Union[str, "Repo"]] = None,
59
+ experiment_name: Optional[str] = None,
60
+ metrics: Optional[List[str]] = None,
61
+ **aim_run_kwargs,
62
+ ):
63
+ """
64
+ See help(AimLoggerCallback) for more information about parameters.
65
+ """
66
+ assert Run is not None, (
67
+ "aim must be installed!. You can install aim with"
68
+ " the command: `pip install aim`."
69
+ )
70
+ self._repo_path = repo
71
+ self._experiment_name = experiment_name
72
+ if not (bool(metrics) or metrics is None):
73
+ raise ValueError(
74
+ "`metrics` must either contain at least one metric name, or be None, "
75
+ "in which case all reported metrics will be logged to the aim repo."
76
+ )
77
+ self._metrics = metrics
78
+ self._aim_run_kwargs = aim_run_kwargs
79
+ self._trial_to_run: Dict["Trial", Run] = {}
80
+
81
+ def _create_run(self, trial: "Trial") -> Run:
82
+ """Initializes an Aim Run object for a given trial.
83
+
84
+ Args:
85
+ trial: The Tune trial that aim will track as a Run.
86
+
87
+ Returns:
88
+ Run: The created aim run for a specific trial.
89
+ """
90
+ experiment_dir = trial.local_experiment_path
91
+ run = Run(
92
+ repo=self._repo_path or experiment_dir,
93
+ experiment=self._experiment_name or trial.experiment_dir_name,
94
+ **self._aim_run_kwargs,
95
+ )
96
+ # Attach a few useful trial properties
97
+ run["trial_id"] = trial.trial_id
98
+ run["trial_log_dir"] = trial.path
99
+ trial_ip = trial.get_ray_actor_ip()
100
+ if trial_ip:
101
+ run["trial_ip"] = trial_ip
102
+ return run
103
+
104
+ def log_trial_start(self, trial: "Trial"):
105
+ if trial in self._trial_to_run:
106
+ # Cleanup an existing run if the trial has been restarted
107
+ self._trial_to_run[trial].close()
108
+
109
+ trial.init_local_path()
110
+ self._trial_to_run[trial] = self._create_run(trial)
111
+
112
+ if trial.evaluated_params:
113
+ self._log_trial_hparams(trial)
114
+
115
+ def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
116
+ tmp_result = result.copy()
117
+
118
+ step = result.get(TIMESTEPS_TOTAL, None) or result[TRAINING_ITERATION]
119
+
120
+ for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]:
121
+ tmp_result.pop(k, None) # not useful to log these
122
+
123
+ # `context` and `epoch` are special keys that users can report,
124
+ # which are treated as special aim metrics/configurations.
125
+ context = tmp_result.pop("context", None)
126
+ epoch = tmp_result.pop("epoch", None)
127
+
128
+ trial_run = self._trial_to_run[trial]
129
+ path = ["ray", "tune"]
130
+
131
+ flat_result = flatten_dict(tmp_result, delimiter="/")
132
+ valid_result = {}
133
+
134
+ for attr, value in flat_result.items():
135
+ if self._metrics and attr not in self._metrics:
136
+ continue
137
+
138
+ full_attr = "/".join(path + [attr])
139
+ if isinstance(value, tuple(VALID_SUMMARY_TYPES)) and not (
140
+ np.isnan(value) or np.isinf(value)
141
+ ):
142
+ valid_result[attr] = value
143
+ trial_run.track(
144
+ value=value,
145
+ name=full_attr,
146
+ epoch=epoch,
147
+ step=step,
148
+ context=context,
149
+ )
150
+ elif (isinstance(value, (list, tuple, set)) and len(value) > 0) or (
151
+ isinstance(value, np.ndarray) and value.size > 0
152
+ ):
153
+ valid_result[attr] = value
154
+
155
+ def log_trial_end(self, trial: "Trial", failed: bool = False):
156
+ trial_run = self._trial_to_run.pop(trial)
157
+ trial_run.close()
158
+
159
+ def _log_trial_hparams(self, trial: "Trial"):
160
+ params = flatten_dict(trial.evaluated_params, delimiter="/")
161
+ flat_params = flatten_dict(params)
162
+
163
+ scrubbed_params = {
164
+ k: v for k, v in flat_params.items() if isinstance(v, self.VALID_HPARAMS)
165
+ }
166
+
167
+ np_params = {
168
+ k: v.tolist()
169
+ for k, v in flat_params.items()
170
+ if isinstance(v, self.VALID_NP_HPARAMS)
171
+ }
172
+
173
+ scrubbed_params.update(np_params)
174
+ removed = {
175
+ k: v
176
+ for k, v in flat_params.items()
177
+ if not isinstance(v, self.VALID_HPARAMS + self.VALID_NP_HPARAMS)
178
+ }
179
+ if removed:
180
+ logger.info(
181
+ "Removed the following hyperparameter values when "
182
+ "logging to aim: %s",
183
+ str(removed),
184
+ )
185
+
186
+ run = self._trial_to_run[trial]
187
+ run["hparams"] = scrubbed_params
.venv/lib/python3.11/site-packages/ray/tune/logger/comet.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from ray.air.integrations.comet import CometLoggerCallback
2
+
3
+ CometLoggerCallback.__module__ = "ray.tune.logger.comet"
.venv/lib/python3.11/site-packages/ray/tune/logger/csv.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import TYPE_CHECKING, Dict, TextIO
5
+
6
+ from ray.air.constants import EXPR_PROGRESS_FILE
7
+ from ray.tune.logger.logger import _LOGGER_DEPRECATION_WARNING, Logger, LoggerCallback
8
+ from ray.tune.utils import flatten_dict
9
+ from ray.util.annotations import Deprecated, PublicAPI
10
+
11
+ if TYPE_CHECKING:
12
+ from ray.tune.experiment.trial import Trial # noqa: F401
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ @Deprecated(
18
+ message=_LOGGER_DEPRECATION_WARNING.format(
19
+ old="CSVLogger", new="ray.tune.csv.CSVLoggerCallback"
20
+ ),
21
+ warning=True,
22
+ )
23
+ @PublicAPI
24
+ class CSVLogger(Logger):
25
+ """Logs results to progress.csv under the trial directory.
26
+
27
+ Automatically flattens nested dicts in the result dict before writing
28
+ to csv:
29
+
30
+ {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
31
+
32
+ """
33
+
34
+ def _init(self):
35
+ self._initialized = False
36
+
37
+ def _maybe_init(self):
38
+ """CSV outputted with Headers as first set of results."""
39
+ if not self._initialized:
40
+ progress_file = Path(self.logdir, EXPR_PROGRESS_FILE)
41
+ self._continuing = (
42
+ progress_file.exists() and progress_file.stat().st_size > 0
43
+ )
44
+ self._file = progress_file.open("a")
45
+ self._csv_out = None
46
+ self._initialized = True
47
+
48
+ def on_result(self, result: Dict):
49
+ self._maybe_init()
50
+
51
+ tmp = result.copy()
52
+ if "config" in tmp:
53
+ del tmp["config"]
54
+ result = flatten_dict(tmp, delimiter="/")
55
+ if self._csv_out is None:
56
+ self._csv_out = csv.DictWriter(self._file, result.keys())
57
+ if not self._continuing:
58
+ self._csv_out.writeheader()
59
+ self._csv_out.writerow(
60
+ {k: v for k, v in result.items() if k in self._csv_out.fieldnames}
61
+ )
62
+ self._file.flush()
63
+
64
+ def flush(self):
65
+ if self._initialized and not self._file.closed:
66
+ self._file.flush()
67
+
68
+ def close(self):
69
+ if self._initialized:
70
+ self._file.close()
71
+
72
+
73
+ @PublicAPI
74
+ class CSVLoggerCallback(LoggerCallback):
75
+ """Logs results to progress.csv under the trial directory.
76
+
77
+ Automatically flattens nested dicts in the result dict before writing
78
+ to csv:
79
+
80
+ {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
81
+
82
+ """
83
+
84
+ _SAVED_FILE_TEMPLATES = [EXPR_PROGRESS_FILE]
85
+
86
+ def __init__(self):
87
+ self._trial_continue: Dict["Trial", bool] = {}
88
+ self._trial_files: Dict["Trial", TextIO] = {}
89
+ self._trial_csv: Dict["Trial", csv.DictWriter] = {}
90
+
91
+ def _setup_trial(self, trial: "Trial"):
92
+ if trial in self._trial_files:
93
+ self._trial_files[trial].close()
94
+
95
+ # Make sure logdir exists
96
+ trial.init_local_path()
97
+ local_file_path = Path(trial.local_path, EXPR_PROGRESS_FILE)
98
+
99
+ # Resume the file from remote storage.
100
+ self._restore_from_remote(EXPR_PROGRESS_FILE, trial)
101
+
102
+ self._trial_continue[trial] = (
103
+ local_file_path.exists() and local_file_path.stat().st_size > 0
104
+ )
105
+
106
+ self._trial_files[trial] = local_file_path.open("at")
107
+ self._trial_csv[trial] = None
108
+
109
+ def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
110
+ if trial not in self._trial_files:
111
+ self._setup_trial(trial)
112
+
113
+ tmp = result.copy()
114
+ tmp.pop("config", None)
115
+ result = flatten_dict(tmp, delimiter="/")
116
+
117
+ if not self._trial_csv[trial]:
118
+ self._trial_csv[trial] = csv.DictWriter(
119
+ self._trial_files[trial], result.keys()
120
+ )
121
+ if not self._trial_continue[trial]:
122
+ self._trial_csv[trial].writeheader()
123
+
124
+ self._trial_csv[trial].writerow(
125
+ {k: v for k, v in result.items() if k in self._trial_csv[trial].fieldnames}
126
+ )
127
+ self._trial_files[trial].flush()
128
+
129
+ def log_trial_end(self, trial: "Trial", failed: bool = False):
130
+ if trial not in self._trial_files:
131
+ return
132
+
133
+ del self._trial_csv[trial]
134
+ self._trial_files[trial].close()
135
+ del self._trial_files[trial]
.venv/lib/python3.11/site-packages/ray/tune/logger/json.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import TYPE_CHECKING, Dict, TextIO
5
+
6
+ import numpy as np
7
+
8
+ import ray.cloudpickle as cloudpickle
9
+ from ray.air.constants import EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE, EXPR_RESULT_FILE
10
+ from ray.tune.logger.logger import _LOGGER_DEPRECATION_WARNING, Logger, LoggerCallback
11
+ from ray.tune.utils.util import SafeFallbackEncoder
12
+ from ray.util.annotations import Deprecated, PublicAPI
13
+
14
+ if TYPE_CHECKING:
15
+ from ray.tune.experiment.trial import Trial # noqa: F401
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ tf = None
20
+ VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
21
+
22
+
23
@Deprecated(
    message=_LOGGER_DEPRECATION_WARNING.format(
        old="JsonLogger", new="ray.tune.json.JsonLoggerCallback"
    ),
    warning=True,
)
@PublicAPI
class JsonLogger(Logger):
    """Logs trial results in json format.

    Also writes to a results file and param.json file when results or
    configurations are updated. Experiments must be executed with the
    JsonLogger to be compatible with the ExperimentAnalysis tool.
    """

    def _init(self):
        # Persist the initial config, then open result.json for appending.
        self.update_config(self.config)
        self.local_out = Path(self.logdir, EXPR_RESULT_FILE).open("a")

    def on_result(self, result: Dict):
        """Append ``result`` as a single JSON line and flush."""
        # ``self`` is passed as the file-like object so that json.dump
        # streams through this logger's write() method.
        json.dump(result, self, cls=SafeFallbackEncoder)
        self.write("\n")
        self.local_out.flush()

    def write(self, b):
        # File-like protocol used by json.dump in on_result().
        self.local_out.write(b)

    def flush(self):
        if not self.local_out.closed:
            self.local_out.flush()

    def close(self):
        self.local_out.close()

    def update_config(self, config: Dict):
        """Dump ``config`` both as pretty-printed JSON and as a pickle."""
        self.config = config
        with Path(self.logdir, EXPR_PARAM_FILE).open("w") as f:
            json.dump(
                self.config, f, indent=2, sort_keys=True, cls=SafeFallbackEncoder
            )
        with Path(self.logdir, EXPR_PARAM_PICKLE_FILE).open("wb") as f:
            cloudpickle.dump(self.config, f)
66
+
67
+
68
@PublicAPI
class JsonLoggerCallback(LoggerCallback):
    """Logs trial results in json format.

    Also writes to a results file and param.json file when results or
    configurations are updated. Experiments must be executed with the
    JsonLoggerCallback to be compatible with the ExperimentAnalysis tool.
    """

    _SAVED_FILE_TEMPLATES = [EXPR_RESULT_FILE, EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE]

    def __init__(self):
        # Last-seen config and open result-file handle, keyed by trial.
        self._trial_configs: Dict["Trial", Dict] = {}
        self._trial_files: Dict["Trial", TextIO] = {}

    def log_trial_start(self, trial: "Trial"):
        """Open (or reopen) result.json for ``trial`` and dump its config."""
        previous = self._trial_files.get(trial)
        if previous is not None:
            previous.close()

        # Snapshot the trial's current config to disk.
        self.update_config(trial, trial.config)

        # The trial's local directory may not exist yet on this node.
        trial.init_local_path()
        result_path = Path(trial.local_path, EXPR_RESULT_FILE)

        # When resuming, pull the previous result file from remote
        # storage so new results are appended rather than overwriting.
        self._restore_from_remote(EXPR_RESULT_FILE, trial)

        self._trial_files[trial] = result_path.open("at")

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Append ``result`` as one JSON line to the trial's result file."""
        if trial not in self._trial_files:
            self.log_trial_start(trial)
        out = self._trial_files[trial]
        json.dump(result, out, cls=SafeFallbackEncoder)
        out.write("\n")
        out.flush()

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Close the trial's result file, if one was ever opened."""
        if trial not in self._trial_files:
            return

        self._trial_files[trial].close()
        del self._trial_files[trial]

    def update_config(self, trial: "Trial", config: Dict):
        """Write ``config`` to params.json and params.pkl in the trial dir."""
        self._trial_configs[trial] = config

        with Path(trial.local_path, EXPR_PARAM_FILE).open("w") as f:
            json.dump(
                self._trial_configs[trial],
                f,
                indent=2,
                sort_keys=True,
                cls=SafeFallbackEncoder,
            )

        with Path(trial.local_path, EXPR_PARAM_PICKLE_FILE).open("wb") as f:
            cloudpickle.dump(self._trial_configs[trial], f)