Prompt48 commited on
Commit
52e4516
·
verified ·
1 Parent(s): 77b5dcc

Upload edit\Qwen3-TTS-test\.venv\Lib\site-packages\accelerate\tracking.py with huggingface_hub

Browse files
edit//Qwen3-TTS-test//.venv//Lib//site-packages//accelerate//tracking.py ADDED
@@ -0,0 +1,1326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Expectation:
16
+ # Provide a project dir name, then each type of logger gets stored in project/{`logging_dir`}
17
+
18
+ import json
19
+ import os
20
+ import time
21
+ from functools import wraps
22
+ from typing import Any, Optional, Union
23
+
24
+ import yaml
25
+ from packaging import version
26
+
27
+ from .logging import get_logger
28
+ from .state import PartialState
29
+ from .utils import (
30
+ LoggerType,
31
+ compare_versions,
32
+ is_aim_available,
33
+ is_clearml_available,
34
+ is_comet_ml_available,
35
+ is_dvclive_available,
36
+ is_mlflow_available,
37
+ is_swanlab_available,
38
+ is_tensorboard_available,
39
+ is_trackio_available,
40
+ is_wandb_available,
41
+ listify,
42
+ )
43
+
44
+
45
+ _available_trackers = []
46
+
47
+ if is_tensorboard_available():
48
+ _available_trackers.append(LoggerType.TENSORBOARD)
49
+
50
+ if is_wandb_available():
51
+ _available_trackers.append(LoggerType.WANDB)
52
+
53
+ if is_comet_ml_available():
54
+ _available_trackers.append(LoggerType.COMETML)
55
+
56
+ if is_aim_available():
57
+ _available_trackers.append(LoggerType.AIM)
58
+
59
+ if is_mlflow_available():
60
+ _available_trackers.append(LoggerType.MLFLOW)
61
+
62
+ if is_clearml_available():
63
+ _available_trackers.append(LoggerType.CLEARML)
64
+
65
+ if is_dvclive_available():
66
+ _available_trackers.append(LoggerType.DVCLIVE)
67
+
68
+ if is_swanlab_available():
69
+ _available_trackers.append(LoggerType.SWANLAB)
70
+
71
+ if is_trackio_available():
72
+ _available_trackers.append(LoggerType.TRACKIO)
73
+
74
+ logger = get_logger(__name__)
75
+
76
+
77
+ def on_main_process(function):
78
+ """
79
+ Decorator to selectively run the decorated function on the main process only based on the `main_process_only`
80
+ attribute in a class.
81
+
82
+ Checks at function execution rather than initialization time, not triggering the initialization of the
83
+ `PartialState`.
84
+ """
85
+
86
+ @wraps(function)
87
+ def execute_on_main_process(self, *args, **kwargs):
88
+ if getattr(self, "main_process_only", False):
89
+ return PartialState().on_main_process(function)(self, *args, **kwargs)
90
+ else:
91
+ return function(self, *args, **kwargs)
92
+
93
+ return execute_on_main_process
94
+
95
+
96
+ def get_available_trackers():
97
+ "Returns a list of all supported available trackers in the system"
98
+ return _available_trackers
99
+
100
+
101
+ class GeneralTracker:
102
+ """
103
+ A base Tracker class to be used for all logging integration implementations.
104
+
105
+ Each function should take in `**kwargs` that will automatically be passed in from a base dictionary provided to
106
+ [`Accelerator`].
107
+
108
+ Should implement `name`, `requires_logging_directory`, and `tracker` properties such that:
109
+
110
+ `name` (`str`): String representation of the tracker class name, such as "TensorBoard" `requires_logging_directory`
111
+ (`bool`): Whether the logger requires a directory to store their logs. `tracker` (`object`): Should return internal
112
+ tracking mechanism used by a tracker class (such as the `run` for wandb)
113
+
114
+ Implementations can also include a `main_process_only` (`bool`) attribute to toggle if relevant logging, init, and
115
+ other functions should occur on the main process or across all processes (by default will use `True`)
116
+ """
117
+
118
+ main_process_only = True
119
+
120
+ def __init__(self, _blank=False):
121
+ if not _blank:
122
+ err = ""
123
+ if not hasattr(self, "name"):
124
+ err += "`name`"
125
+ if not hasattr(self, "requires_logging_directory"):
126
+ if len(err) > 0:
127
+ err += ", "
128
+ err += "`requires_logging_directory`"
129
+
130
+ # as tracker is a @property that relies on post-init
131
+ if "tracker" not in dir(self):
132
+ if len(err) > 0:
133
+ err += ", "
134
+ err += "`tracker`"
135
+ if len(err) > 0:
136
+ raise NotImplementedError(
137
+ f"The implementation for this tracker class is missing the following "
138
+ f"required attributes. Please define them in the class definition: "
139
+ f"{err}"
140
+ )
141
+
142
+ def start(self):
143
+ """
144
+ Lazy initialization of the tracker inside Accelerator to avoid initializing PartialState before
145
+ InitProcessGroupKwargs.
146
+ """
147
+ pass
148
+
149
+ def store_init_configuration(self, values: dict):
150
+ """
151
+ Logs `values` as hyperparameters for the run. Implementations should use the experiment configuration
152
+ functionality of a tracking API.
153
+
154
+ Args:
155
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
156
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
157
+ `str`, `float`, `int`, or `None`.
158
+ """
159
+ pass
160
+
161
+ def log(self, values: dict, step: Optional[int], **kwargs):
162
+ """
163
+ Logs `values` to the current run. Base `log` implementations of a tracking API should go in here, along with
164
+ special behavior for the `step parameter.
165
+
166
+ Args:
167
+ values (Dictionary `str` to `str`, `float`, or `int`):
168
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.
169
+ step (`int`, *optional*):
170
+ The run step. If included, the log will be affiliated with this step.
171
+ """
172
+ pass
173
+
174
+ def finish(self):
175
+ """
176
+ Should run any finalizing functions within the tracking API. If the API should not have one, just don't
177
+ overwrite that method.
178
+ """
179
+ pass
180
+
181
+
182
+ class TensorBoardTracker(GeneralTracker):
183
+ """
184
+ A `Tracker` class that supports `tensorboard`. Should be initialized at the start of your script.
185
+
186
+ Args:
187
+ run_name (`str`):
188
+ The name of the experiment run
189
+ logging_dir (`str`, `os.PathLike`):
190
+ Location for TensorBoard logs to be stored.
191
+ **kwargs (additional keyword arguments, *optional*):
192
+ Additional key word arguments passed along to the `tensorboard.SummaryWriter.__init__` method.
193
+ """
194
+
195
+ name = "tensorboard"
196
+ requires_logging_directory = True
197
+
198
+ def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike], **kwargs):
199
+ super().__init__()
200
+ self.run_name = run_name
201
+ self.logging_dir_param = logging_dir
202
+ self.init_kwargs = kwargs
203
+
204
+ @on_main_process
205
+ def start(self):
206
+ try:
207
+ from torch.utils import tensorboard
208
+ except ModuleNotFoundError:
209
+ import tensorboardX as tensorboard
210
+ self.logging_dir = os.path.join(self.logging_dir_param, self.run_name)
211
+ self.writer = tensorboard.SummaryWriter(self.logging_dir, **self.init_kwargs)
212
+ logger.debug(f"Initialized TensorBoard project {self.run_name} logging to {self.logging_dir}")
213
+ logger.debug(
214
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
215
+ )
216
+
217
+ @property
218
+ def tracker(self):
219
+ return self.writer
220
+
221
+ @on_main_process
222
+ def store_init_configuration(self, values: dict):
223
+ """
224
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
225
+ hyperparameters in a yaml file for future use.
226
+
227
+ Args:
228
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
229
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
230
+ `str`, `float`, `int`, or `None`.
231
+ """
232
+ self.writer.add_hparams(values, metric_dict={})
233
+ self.writer.flush()
234
+ project_run_name = time.time()
235
+ dir_name = os.path.join(self.logging_dir, str(project_run_name))
236
+ os.makedirs(dir_name, exist_ok=True)
237
+ with open(os.path.join(dir_name, "hparams.yml"), "w") as outfile:
238
+ try:
239
+ yaml.dump(values, outfile)
240
+ except yaml.representer.RepresenterError:
241
+ logger.error("Serialization to store hyperparameters failed")
242
+ raise
243
+ logger.debug("Stored initial configuration hyperparameters to TensorBoard and hparams yaml file")
244
+
245
+ @on_main_process
246
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
247
+ """
248
+ Logs `values` to the current run.
249
+
250
+ Args:
251
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
252
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
253
+ `str` to `float`/`int`.
254
+ step (`int`, *optional*):
255
+ The run step. If included, the log will be affiliated with this step.
256
+ kwargs:
257
+ Additional key word arguments passed along to either `SummaryWriter.add_scaler`,
258
+ `SummaryWriter.add_text`, or `SummaryWriter.add_scalers` method based on the contents of `values`.
259
+ """
260
+ values = listify(values)
261
+ for k, v in values.items():
262
+ if isinstance(v, (int, float)):
263
+ self.writer.add_scalar(k, v, global_step=step, **kwargs)
264
+ elif isinstance(v, str):
265
+ self.writer.add_text(k, v, global_step=step, **kwargs)
266
+ elif isinstance(v, dict):
267
+ self.writer.add_scalars(k, v, global_step=step, **kwargs)
268
+ self.writer.flush()
269
+ logger.debug("Successfully logged to TensorBoard")
270
+
271
+ @on_main_process
272
+ def log_images(self, values: dict, step: Optional[int], **kwargs):
273
+ """
274
+ Logs `images` to the current run.
275
+
276
+ Args:
277
+ values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
278
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
279
+ step (`int`, *optional*):
280
+ The run step. If included, the log will be affiliated with this step.
281
+ kwargs:
282
+ Additional key word arguments passed along to the `SummaryWriter.add_image` method.
283
+ """
284
+ for k, v in values.items():
285
+ self.writer.add_images(k, v, global_step=step, **kwargs)
286
+ logger.debug("Successfully logged images to TensorBoard")
287
+
288
+ @on_main_process
289
+ def finish(self):
290
+ """
291
+ Closes `TensorBoard` writer
292
+ """
293
+ self.writer.close()
294
+ logger.debug("TensorBoard writer closed")
295
+
296
+
297
+ class WandBTracker(GeneralTracker):
298
+ """
299
+ A `Tracker` class that supports `wandb`. Should be initialized at the start of your script.
300
+
301
+ Args:
302
+ run_name (`str`):
303
+ The name of the experiment run.
304
+ **kwargs (additional keyword arguments, *optional*):
305
+ Additional key word arguments passed along to the `wandb.init` method.
306
+ """
307
+
308
+ name = "wandb"
309
+ requires_logging_directory = False
310
+ main_process_only = False
311
+
312
+ def __init__(self, run_name: str, **kwargs):
313
+ super().__init__()
314
+ self.run_name = run_name
315
+ self.init_kwargs = kwargs
316
+
317
+ @on_main_process
318
+ def start(self):
319
+ import wandb
320
+
321
+ self.run = wandb.init(project=self.run_name, **self.init_kwargs)
322
+ logger.debug(f"Initialized WandB project {self.run_name}")
323
+ logger.debug(
324
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
325
+ )
326
+
327
+ @property
328
+ def tracker(self):
329
+ return self.run
330
+
331
+ @on_main_process
332
+ def store_init_configuration(self, values: dict):
333
+ """
334
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
335
+
336
+ Args:
337
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
338
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
339
+ `str`, `float`, `int`, or `None`.
340
+ """
341
+ import wandb
342
+
343
+ if os.environ.get("WANDB_MODE") == "offline":
344
+ # In offline mode, restart wandb with config included
345
+ if hasattr(self, "run") and self.run:
346
+ self.run.finish()
347
+
348
+ init_kwargs = self.init_kwargs.copy()
349
+ init_kwargs["config"] = values
350
+ self.run = wandb.init(project=self.run_name, **init_kwargs)
351
+ else:
352
+ wandb.config.update(values, allow_val_change=True)
353
+ logger.debug("Stored initial configuration hyperparameters to WandB")
354
+
355
+ @on_main_process
356
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
357
+ """
358
+ Logs `values` to the current run.
359
+
360
+ Args:
361
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
362
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
363
+ `str` to `float`/`int`.
364
+ step (`int`, *optional*):
365
+ The run step. If included, the log will be affiliated with this step.
366
+ kwargs:
367
+ Additional key word arguments passed along to the `wandb.log` method.
368
+ """
369
+ self.run.log(values, step=step, **kwargs)
370
+ logger.debug("Successfully logged to WandB")
371
+
372
+ @on_main_process
373
+ def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
374
+ """
375
+ Logs `images` to the current run.
376
+
377
+ Args:
378
+ values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
379
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
380
+ step (`int`, *optional*):
381
+ The run step. If included, the log will be affiliated with this step.
382
+ kwargs:
383
+ Additional key word arguments passed along to the `wandb.log` method.
384
+ """
385
+ import wandb
386
+
387
+ for k, v in values.items():
388
+ self.log({k: [wandb.Image(image) for image in v]}, step=step, **kwargs)
389
+ logger.debug("Successfully logged images to WandB")
390
+
391
+ @on_main_process
392
+ def log_table(
393
+ self,
394
+ table_name: str,
395
+ columns: Optional[list[str]] = None,
396
+ data: Optional[list[list[Any]]] = None,
397
+ dataframe: Any = None,
398
+ step: Optional[int] = None,
399
+ **kwargs,
400
+ ):
401
+ """
402
+ Log a Table containing any object type (text, image, audio, video, molecule, html, etc). Can be defined either
403
+ with `columns` and `data` or with `dataframe`.
404
+
405
+ Args:
406
+ table_name (`str`):
407
+ The name to give to the logged table on the wandb workspace
408
+ columns (list of `str`, *optional*):
409
+ The name of the columns on the table
410
+ data (List of List of Any data type, *optional*):
411
+ The data to be logged in the table
412
+ dataframe (Any data type, *optional*):
413
+ The data to be logged in the table
414
+ step (`int`, *optional*):
415
+ The run step. If included, the log will be affiliated with this step.
416
+ """
417
+ import wandb
418
+
419
+ values = {table_name: wandb.Table(columns=columns, data=data, dataframe=dataframe)}
420
+ self.log(values, step=step, **kwargs)
421
+
422
+ @on_main_process
423
+ def finish(self):
424
+ """
425
+ Closes `wandb` writer
426
+ """
427
+ self.run.finish()
428
+ logger.debug("WandB run closed")
429
+
430
+
431
+ class TrackioTracker(GeneralTracker):
432
+ """
433
+ A `Tracker` class that supports `trackio`. Should be initialized at the start of your script.
434
+
435
+ Args:
436
+ run_name (`str`):
437
+ The name of the experiment run. Will be used as the `project` name when instantiating trackio.
438
+ **kwargs (additional keyword arguments, *optional*):
439
+ Additional key word arguments passed along to the `trackio.init` method. Refer to this
440
+ [init](https://github.com/gradio-app/trackio/blob/814809552310468b13f84f33764f1369b4e5136c/trackio/__init__.py#L22)
441
+ to see all supported key word arguments.
442
+ """
443
+
444
+ name = "trackio"
445
+ requires_logging_directory = False
446
+ main_process_only = False
447
+
448
+ def __init__(self, run_name: str, **kwargs):
449
+ super().__init__()
450
+ self.run_name = run_name
451
+ self.init_kwargs = kwargs
452
+
453
+ @on_main_process
454
+ def start(self):
455
+ import trackio
456
+
457
+ self.run = trackio.init(project=self.run_name, **self.init_kwargs)
458
+ logger.debug(f"Initialized trackio project {self.run_name}")
459
+ logger.debug(
460
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
461
+ )
462
+
463
+ @property
464
+ def tracker(self):
465
+ return self.run
466
+
467
+ @on_main_process
468
+ def store_init_configuration(self, values: dict):
469
+ """
470
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
471
+
472
+ Args:
473
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
474
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
475
+ `str`, `float`, `int`, or `None`.
476
+ """
477
+ import trackio
478
+
479
+ trackio.config.update(values, allow_val_change=True)
480
+ logger.debug("Stored initial configuration hyperparameters to trackio")
481
+
482
+ @on_main_process
483
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
484
+ """
485
+ Logs `values` to the current run.
486
+
487
+ Args:
488
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
489
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
490
+ `str` to `float`/`int`.
491
+ step (`int`, *optional*):
492
+ The run step. If included, the log will be affiliated with this step.
493
+ kwargs:
494
+ Additional key word arguments passed along to the `trackio.log` method.
495
+ """
496
+ self.run.log(values, **kwargs)
497
+ logger.debug("Successfully logged to trackio")
498
+
499
+ @on_main_process
500
+ def finish(self):
501
+ """
502
+ Closes `trackio` run
503
+ """
504
+ self.run.finish()
505
+ logger.debug("trackio run closed")
506
+
507
+
508
+ class CometMLTracker(GeneralTracker):
509
+ """
510
+ A `Tracker` class that supports `comet_ml`. Should be initialized at the start of your script.
511
+
512
+ API keys must be stored in a Comet config file.
513
+
514
+ Note:
515
+ For `comet_ml` versions < 3.41.0, additional keyword arguments are passed to `comet_ml.Experiment` instead:
516
+ https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment/#comet_ml.Experiment.__init__
517
+
518
+ Args:
519
+ run_name (`str`):
520
+ The name of the experiment run.
521
+ **kwargs (additional keyword arguments, *optional*):
522
+ Additional key word arguments passed along to the `comet_ml.start` method:
523
+ https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/start/
524
+ """
525
+
526
+ name = "comet_ml"
527
+ requires_logging_directory = False
528
+
529
+ def __init__(self, run_name: str, **kwargs):
530
+ super().__init__()
531
+ self.run_name = run_name
532
+ self.init_kwargs = kwargs
533
+
534
+ @on_main_process
535
+ def start(self):
536
+ import comet_ml
537
+
538
+ comet_version = version.parse(comet_ml.__version__)
539
+ if compare_versions(comet_version, ">=", "3.41.0"):
540
+ self.writer = comet_ml.start(project_name=self.run_name, **self.init_kwargs)
541
+ else:
542
+ logger.info("Update `comet_ml` (>=3.41.0) for experiment reuse and offline support.")
543
+ self.writer = comet_ml.Experiment(project_name=self.run_name, **self.init_kwargs)
544
+
545
+ logger.debug(f"Initialized CometML project {self.run_name}")
546
+ logger.debug(
547
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
548
+ )
549
+
550
+ @property
551
+ def tracker(self):
552
+ return self.writer
553
+
554
+ @on_main_process
555
+ def store_init_configuration(self, values: dict):
556
+ """
557
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
558
+
559
+ Args:
560
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
561
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
562
+ `str`, `float`, `int`, or `None`.
563
+ """
564
+ self.writer.log_parameters(values)
565
+ logger.debug("Stored initial configuration hyperparameters to Comet")
566
+
567
+ @on_main_process
568
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
569
+ """
570
+ Logs `values` to the current run.
571
+
572
+ Args:
573
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
574
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
575
+ `str` to `float`/`int`.
576
+ step (`int`, *optional*):
577
+ The run step. If included, the log will be affiliated with this step.
578
+ kwargs:
579
+ Additional key word arguments passed along to either `Experiment.log_metric`, `Experiment.log_other`,
580
+ or `Experiment.log_metrics` method based on the contents of `values`.
581
+ """
582
+ if step is not None:
583
+ self.writer.set_step(step)
584
+ for k, v in values.items():
585
+ if isinstance(v, (int, float)):
586
+ self.writer.log_metric(k, v, step=step, **kwargs)
587
+ elif isinstance(v, str):
588
+ self.writer.log_other(k, v, **kwargs)
589
+ elif isinstance(v, dict):
590
+ self.writer.log_metrics(v, step=step, **kwargs)
591
+ logger.debug("Successfully logged to Comet")
592
+
593
+ @on_main_process
594
+ def finish(self):
595
+ """
596
+ Flush `comet-ml` writer
597
+ """
598
+ self.writer.end()
599
+ logger.debug("Comet run flushed")
600
+
601
+
602
+ class AimTracker(GeneralTracker):
603
+ """
604
+ A `Tracker` class that supports `aim`. Should be initialized at the start of your script.
605
+
606
+ Args:
607
+ run_name (`str`):
608
+ The name of the experiment run.
609
+ **kwargs (additional keyword arguments, *optional*):
610
+ Additional key word arguments passed along to the `Run.__init__` method.
611
+ """
612
+
613
+ name = "aim"
614
+ requires_logging_directory = True
615
+
616
+ def __init__(self, run_name: str, logging_dir: Optional[Union[str, os.PathLike]] = ".", **kwargs):
617
+ super().__init__()
618
+ self.run_name = run_name
619
+ self.aim_repo_path = logging_dir
620
+ self.init_kwargs = kwargs
621
+
622
+ @on_main_process
623
+ def start(self):
624
+ from aim import Run
625
+
626
+ self.writer = Run(repo=self.aim_repo_path, **self.init_kwargs)
627
+ self.writer.name = self.run_name
628
+ logger.debug(f"Initialized Aim project {self.run_name}")
629
+ logger.debug(
630
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
631
+ )
632
+
633
+ @property
634
+ def tracker(self):
635
+ return self.writer
636
+
637
+ @on_main_process
638
+ def store_init_configuration(self, values: dict):
639
+ """
640
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
641
+
642
+ Args:
643
+ values (`dict`):
644
+ Values to be stored as initial hyperparameters as key-value pairs.
645
+ """
646
+ self.writer["hparams"] = values
647
+
648
+ @on_main_process
649
+ def log(self, values: dict, step: Optional[int], **kwargs):
650
+ """
651
+ Logs `values` to the current run.
652
+
653
+ Args:
654
+ values (`dict`):
655
+ Values to be logged as key-value pairs.
656
+ step (`int`, *optional*):
657
+ The run step. If included, the log will be affiliated with this step.
658
+ kwargs:
659
+ Additional key word arguments passed along to the `Run.track` method.
660
+ """
661
+ # Note: replace this with the dictionary support when merged
662
+ for key, value in values.items():
663
+ self.writer.track(value, name=key, step=step, **kwargs)
664
+
665
+ @on_main_process
666
+ def log_images(self, values: dict, step: Optional[int] = None, kwargs: Optional[dict[str, dict]] = None):
667
+ """
668
+ Logs `images` to the current run.
669
+
670
+ Args:
671
+ values (`Dict[str, Union[np.ndarray, PIL.Image, Tuple[np.ndarray, str], Tuple[PIL.Image, str]]]`):
672
+ Values to be logged as key-value pairs. The values need to have type `np.ndarray` or PIL.Image. If a
673
+ tuple is provided, the first element should be the image and the second element should be the caption.
674
+ step (`int`, *optional*):
675
+ The run step. If included, the log will be affiliated with this step.
676
+ kwargs (`Dict[str, dict]`):
677
+ Additional key word arguments passed along to the `Run.Image` and `Run.track` method specified by the
678
+ keys `aim_image` and `track`, respectively.
679
+ """
680
+ import aim
681
+
682
+ aim_image_kw = {}
683
+ track_kw = {}
684
+
685
+ if kwargs is not None:
686
+ aim_image_kw = kwargs.get("aim_image", {})
687
+ track_kw = kwargs.get("track", {})
688
+
689
+ for key, value in values.items():
690
+ if isinstance(value, tuple):
691
+ img, caption = value
692
+ else:
693
+ img, caption = value, ""
694
+ aim_image = aim.Image(img, caption=caption, **aim_image_kw)
695
+ self.writer.track(aim_image, name=key, step=step, **track_kw)
696
+
697
+ @on_main_process
698
+ def finish(self):
699
+ """
700
+ Closes `aim` writer
701
+ """
702
+ self.writer.close()
703
+
704
+
705
+ class MLflowTracker(GeneralTracker):
706
+ """
707
+ A `Tracker` class that supports `mlflow`. Should be initialized at the start of your script.
708
+
709
+ Args:
710
+ experiment_name (`str`, *optional*):
711
+ Name of the experiment. Environment variable MLFLOW_EXPERIMENT_NAME has priority over this argument.
712
+ logging_dir (`str` or `os.PathLike`, defaults to `"."`):
713
+ Location for mlflow logs to be stored.
714
+ run_id (`str`, *optional*):
715
+ If specified, get the run with the specified UUID and log parameters and metrics under that run. The run’s
716
+ end time is unset and its status is set to running, but the run’s other attributes (source_version,
717
+ source_type, etc.) are not changed. Environment variable MLFLOW_RUN_ID has priority over this argument.
718
+ tags (`Dict[str, str]`, *optional*):
719
+ An optional `dict` of `str` keys and values, or a `str` dump from a `dict`, to set as tags on the run. If a
720
+ run is being resumed, these tags are set on the resumed run. If a new run is being created, these tags are
721
+ set on the new run. Environment variable MLFLOW_TAGS has priority over this argument.
722
+ nested_run (`bool`, *optional*, defaults to `False`):
723
+ Controls whether run is nested in parent run. True creates a nested run. Environment variable
724
+ MLFLOW_NESTED_RUN has priority over this argument.
725
+ run_name (`str`, *optional*):
726
+ Name of new run (stored as a mlflow.runName tag). Used only when `run_id` is unspecified.
727
+ description (`str`, *optional*):
728
+ An optional string that populates the description box of the run. If a run is being resumed, the
729
+ description is set on the resumed run. If a new run is being created, the description is set on the new
730
+ run.
731
+ """
732
+
733
+ name = "mlflow"
734
+ requires_logging_directory = False
735
+
736
+ def __init__(
737
+ self,
738
+ experiment_name: Optional[str] = None,
739
+ logging_dir: Optional[Union[str, os.PathLike]] = None,
740
+ run_id: Optional[str] = None,
741
+ tags: Optional[Union[dict[str, Any], str]] = None,
742
+ nested_run: Optional[bool] = False,
743
+ run_name: Optional[str] = None,
744
+ description: Optional[str] = None,
745
+ ):
746
+ experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME", experiment_name)
747
+ run_id = os.environ.get("MLFLOW_RUN_ID", run_id)
748
+ tags = os.environ.get("MLFLOW_TAGS", tags)
749
+ if isinstance(tags, str):
750
+ tags = json.loads(tags)
751
+
752
+ nested_run = os.environ.get("MLFLOW_NESTED_RUN", nested_run)
753
+
754
+ self.experiment_name = experiment_name
755
+ self.logging_dir = logging_dir
756
+ self.run_id = run_id
757
+ self.tags = tags
758
+ self.nested_run = nested_run
759
+ self.run_name = run_name
760
+ self.description = description
761
+
762
+ @on_main_process
763
+ def start(self):
764
+ import mlflow
765
+
766
+ exps = mlflow.search_experiments(filter_string=f"name = '{self.experiment_name}'")
767
+ if len(exps) > 0:
768
+ if len(exps) > 1:
769
+ logger.warning("Multiple experiments with the same name found. Using first one.")
770
+ experiment_id = exps[0].experiment_id
771
+ else:
772
+ experiment_id = mlflow.create_experiment(
773
+ name=self.experiment_name,
774
+ artifact_location=self.logging_dir,
775
+ tags=self.tags,
776
+ )
777
+
778
+ self.active_run = mlflow.start_run(
779
+ run_id=self.run_id,
780
+ experiment_id=experiment_id,
781
+ run_name=self.run_name,
782
+ nested=self.nested_run,
783
+ tags=self.tags,
784
+ description=self.description,
785
+ )
786
+
787
+ logger.debug(f"Initialized mlflow experiment {self.experiment_name}")
788
+ logger.debug(
789
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
790
+ )
791
+
792
+ @property
793
+ def tracker(self):
794
+ return self.active_run
795
+
796
+ @on_main_process
797
+ def store_init_configuration(self, values: dict):
798
+ """
799
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
800
+
801
+ Args:
802
+ values (`dict`):
803
+ Values to be stored as initial hyperparameters as key-value pairs.
804
+ """
805
+ import mlflow
806
+
807
+ for name, value in list(values.items()):
808
+ # internally, all values are converted to str in MLflow
809
+ if len(str(value)) > mlflow.utils.validation.MAX_PARAM_VAL_LENGTH:
810
+ logger.warning_once(
811
+ f'Accelerate is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s'
812
+ f" log_param() only accepts values no longer than {mlflow.utils.validation.MAX_PARAM_VAL_LENGTH} characters so we dropped this attribute."
813
+ )
814
+ del values[name]
815
+
816
+ values_list = list(values.items())
817
+
818
+ # MLflow cannot log more than 100 values in one go, so we have to split it
819
+ for i in range(0, len(values_list), mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH):
820
+ mlflow.log_params(dict(values_list[i : i + mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH]))
821
+
822
+ logger.debug("Stored initial configuration hyperparameters to MLflow")
823
+
824
+ @on_main_process
825
+ def log(self, values: dict, step: Optional[int]):
826
+ """
827
+ Logs `values` to the current run.
828
+
829
+ Args:
830
+ values (`dict`):
831
+ Values to be logged as key-value pairs.
832
+ step (`int`, *optional*):
833
+ The run step. If included, the log will be affiliated with this step.
834
+ """
835
+ metrics = {}
836
+ for k, v in values.items():
837
+ if isinstance(v, (int, float)):
838
+ metrics[k] = v
839
+ else:
840
+ logger.warning_once(
841
+ f'MLflowTracker is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
842
+ "MLflow's log_metric() only accepts float and int types so we dropped this attribute."
843
+ )
844
+ import mlflow
845
+
846
+ mlflow.log_metrics(metrics, step=step)
847
+ logger.debug("Successfully logged to mlflow")
848
+
849
+ @on_main_process
850
+ def log_figure(self, figure: Any, artifact_file: str, **save_kwargs):
851
+ """
852
+ Logs an figure to the current run.
853
+
854
+ Args:
855
+ figure (Any):
856
+ The figure to be logged.
857
+ artifact_file (`str`, *optional*):
858
+ The run-relative artifact file path in posixpath format to which the image is saved.
859
+ If not provided, the image is saved to a default location.
860
+ **kwargs:
861
+ Additional keyword arguments passed to the underlying mlflow.log_image function.
862
+ """
863
+ import mlflow
864
+
865
+ mlflow.log_figure(figure=figure, artifact_file=artifact_file, **save_kwargs)
866
+ logger.debug("Successfully logged image to mlflow")
867
+
868
+ @on_main_process
869
+ def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None):
870
+ """
871
+ Logs an artifacts (all content of a dir) to the current run.
872
+
873
+ local_dir (`str`):
874
+ Path to the directory to be logged as an artifact.
875
+ artifact_path (`str`, *optional*):
876
+ Directory within the run's artifact directory where the artifact will be logged. If omitted, the
877
+ artifact will be logged to the root of the run's artifact directory. The run step. If included, the
878
+ artifact will be affiliated with this step.
879
+ """
880
+ import mlflow
881
+
882
+ mlflow.log_artifacts(local_dir=local_dir, artifact_path=artifact_path)
883
+ logger.debug("Successfully logged artofact to mlflow")
884
+
885
+ @on_main_process
886
+ def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):
887
+ """
888
+ Logs an artifact (file) to the current run.
889
+
890
+ local_path (`str`):
891
+ Path to the file to be logged as an artifact.
892
+ artifact_path (`str`, *optional*):
893
+ Directory within the run's artifact directory where the artifact will be logged. If omitted, the
894
+ artifact will be logged to the root of the run's artifact directory. The run step. If included, the
895
+ artifact will be affiliated with this step.
896
+ """
897
+ import mlflow
898
+
899
+ mlflow.log_artifact(local_path=local_path, artifact_path=artifact_path)
900
+ logger.debug("Successfully logged artofact to mlflow")
901
+
902
+ @on_main_process
903
+ def finish(self):
904
+ """
905
+ End the active MLflow run.
906
+ """
907
+ import mlflow
908
+
909
+ mlflow.end_run()
910
+
911
+
912
+ class ClearMLTracker(GeneralTracker):
913
+ """
914
+ A `Tracker` class that supports `clearml`. Should be initialized at the start of your script.
915
+
916
+ Args:
917
+ run_name (`str`, *optional*):
918
+ Name of the experiment. Environment variables `CLEARML_PROJECT` and `CLEARML_TASK` have priority over this
919
+ argument.
920
+ **kwargs (additional keyword arguments, *optional*):
921
+ Kwargs passed along to the `Task.__init__` method.
922
+ """
923
+
924
+ name = "clearml"
925
+ requires_logging_directory = False
926
+
927
+ def __init__(self, run_name: Optional[str] = None, **kwargs):
928
+ super().__init__()
929
+ self.user_provided_run_name = run_name
930
+ self._initialized_externally = False
931
+ self.init_kwargs = kwargs
932
+
933
+ @on_main_process
934
+ def start(self):
935
+ from clearml import Task
936
+
937
+ current_task = Task.current_task()
938
+ if current_task:
939
+ self._initialized_externally = True
940
+ self.task = current_task
941
+ return
942
+
943
+ task_init_args = {**self.init_kwargs}
944
+ task_init_args.setdefault("project_name", os.environ.get("CLEARML_PROJECT", self.user_provided_run_name))
945
+ task_init_args.setdefault("task_name", os.environ.get("CLEARML_TASK", self.user_provided_run_name))
946
+ self.task = Task.init(**task_init_args)
947
+
948
+ @property
949
+ def tracker(self):
950
+ return self.task
951
+
952
+ @on_main_process
953
+ def store_init_configuration(self, values: dict):
954
+ """
955
+ Connect configuration dictionary to the Task object. Should be run at the beginning of your experiment.
956
+
957
+ Args:
958
+ values (`dict`):
959
+ Values to be stored as initial hyperparameters as key-value pairs.
960
+ """
961
+ return self.task.connect_configuration(values)
962
+
963
+ @on_main_process
964
+ def log(self, values: dict[str, Union[int, float]], step: Optional[int] = None, **kwargs):
965
+ """
966
+ Logs `values` dictionary to the current run. The dictionary keys must be strings. The dictionary values must be
967
+ ints or floats
968
+
969
+ Args:
970
+ values (`Dict[str, Union[int, float]]`):
971
+ Values to be logged as key-value pairs. If the key starts with 'eval_'/'test_'/'train_', the value will
972
+ be reported under the 'eval'/'test'/'train' series and the respective prefix will be removed.
973
+ Otherwise, the value will be reported under the 'train' series, and no prefix will be removed.
974
+ step (`int`, *optional*):
975
+ If specified, the values will be reported as scalars, with the iteration number equal to `step`.
976
+ Otherwise they will be reported as single values.
977
+ kwargs:
978
+ Additional key word arguments passed along to the `clearml.Logger.report_single_value` or
979
+ `clearml.Logger.report_scalar` methods.
980
+ """
981
+ clearml_logger = self.task.get_logger()
982
+ for k, v in values.items():
983
+ if not isinstance(v, (int, float)):
984
+ logger.warning_once(
985
+ "Accelerator is attempting to log a value of "
986
+ f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
987
+ "This invocation of ClearML logger's report_scalar() "
988
+ "is incorrect so we dropped this attribute."
989
+ )
990
+ continue
991
+ if step is None:
992
+ clearml_logger.report_single_value(name=k, value=v, **kwargs)
993
+ continue
994
+ title, series = ClearMLTracker._get_title_series(k)
995
+ clearml_logger.report_scalar(title=title, series=series, value=v, iteration=step, **kwargs)
996
+
997
+ @on_main_process
998
+ def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
999
+ """
1000
+ Logs `images` to the current run.
1001
+
1002
+ Args:
1003
+ values (`Dict[str, List[Union[np.ndarray, PIL.Image]]`):
1004
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
1005
+ step (`int`, *optional*):
1006
+ The run step. If included, the log will be affiliated with this step.
1007
+ kwargs:
1008
+ Additional key word arguments passed along to the `clearml.Logger.report_image` method.
1009
+ """
1010
+ clearml_logger = self.task.get_logger()
1011
+ for k, v in values.items():
1012
+ title, series = ClearMLTracker._get_title_series(k)
1013
+ clearml_logger.report_image(title=title, series=series, iteration=step, image=v, **kwargs)
1014
+
1015
+ @on_main_process
1016
+ def log_table(
1017
+ self,
1018
+ table_name: str,
1019
+ columns: Optional[list[str]] = None,
1020
+ data: Optional[list[list[Any]]] = None,
1021
+ dataframe: Any = None,
1022
+ step: Optional[int] = None,
1023
+ **kwargs,
1024
+ ):
1025
+ """
1026
+ Log a Table to the task. Can be defined eitherwith `columns` and `data` or with `dataframe`.
1027
+
1028
+ Args:
1029
+ table_name (`str`):
1030
+ The name of the table
1031
+ columns (list of `str`, *optional*):
1032
+ The name of the columns on the table
1033
+ data (List of List of Any data type, *optional*):
1034
+ The data to be logged in the table. If `columns` is not specified, then the first entry in data will be
1035
+ the name of the columns of the table
1036
+ dataframe (Any data type, *optional*):
1037
+ The data to be logged in the table
1038
+ step (`int`, *optional*):
1039
+ The run step. If included, the log will be affiliated with this step.
1040
+ kwargs:
1041
+ Additional key word arguments passed along to the `clearml.Logger.report_table` method.
1042
+ """
1043
+ to_report = dataframe
1044
+ if dataframe is None:
1045
+ if data is None:
1046
+ raise ValueError(
1047
+ "`ClearMLTracker.log_table` requires that `data` to be supplied if `dataframe` is `None`"
1048
+ )
1049
+ to_report = [columns] + data if columns else data
1050
+ title, series = ClearMLTracker._get_title_series(table_name)
1051
+ self.task.get_logger().report_table(title=title, series=series, table_plot=to_report, iteration=step, **kwargs)
1052
+
1053
+ @on_main_process
1054
+ def finish(self):
1055
+ """
1056
+ Close the ClearML task. If the task was initialized externally (e.g. by manually calling `Task.init`), this
1057
+ function is a noop
1058
+ """
1059
+ if self.task and not self._initialized_externally:
1060
+ self.task.close()
1061
+
1062
+ @staticmethod
1063
+ def _get_title_series(name):
1064
+ for prefix in ["eval", "test", "train"]:
1065
+ if name.startswith(prefix + "_"):
1066
+ return name[len(prefix) + 1 :], prefix
1067
+ return name, "train"
1068
+
1069
+
1070
+ class DVCLiveTracker(GeneralTracker):
1071
+ """
1072
+ A `Tracker` class that supports `dvclive`. Should be initialized at the start of your script.
1073
+
1074
+ Args:
1075
+ run_name (`str`, *optional*):
1076
+ Ignored for dvclive. See `kwargs` instead.
1077
+ kwargs:
1078
+ Additional key word arguments passed along to [`dvclive.Live()`](https://dvc.org/doc/dvclive/live).
1079
+
1080
+ Example:
1081
+
1082
+ ```py
1083
+ from accelerate import Accelerator
1084
+
1085
+ accelerator = Accelerator(log_with="dvclive")
1086
+ accelerator.init_trackers(project_name="my_project", init_kwargs={"dvclive": {"dir": "my_directory"}})
1087
+ ```
1088
+ """
1089
+
1090
+ name = "dvclive"
1091
+ requires_logging_directory = False
1092
+
1093
+ def __init__(self, run_name: Optional[str] = None, live: Optional[Any] = None, **kwargs):
1094
+ super().__init__()
1095
+ self.live = live
1096
+ self.init_kwargs = kwargs
1097
+
1098
+ @on_main_process
1099
+ def start(self):
1100
+ from dvclive import Live
1101
+
1102
+ self.live = self.live if self.live is not None else Live(**self.init_kwargs)
1103
+
1104
+ @property
1105
+ def tracker(self):
1106
+ return self.live
1107
+
1108
+ @on_main_process
1109
+ def store_init_configuration(self, values: dict):
1110
+ """
1111
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
1112
+ hyperparameters in a yaml file for future use.
1113
+
1114
+ Args:
1115
+ values (Dictionary `str` to `bool`, `str`, `float`, `int`, or a List or Dict of those types):
1116
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
1117
+ `str`, `float`, or `int`.
1118
+ """
1119
+ self.live.log_params(values)
1120
+
1121
+ @on_main_process
1122
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
1123
+ """
1124
+ Logs `values` to the current run.
1125
+
1126
+ Args:
1127
+ values (Dictionary `str` to `str`, `float`, or `int`):
1128
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.
1129
+ step (`int`, *optional*):
1130
+ The run step. If included, the log will be affiliated with this step.
1131
+ kwargs:
1132
+ Additional key word arguments passed along to `dvclive.Live.log_metric()`.
1133
+ """
1134
+ from dvclive.plots import Metric
1135
+
1136
+ if step is not None:
1137
+ self.live.step = step
1138
+ for k, v in values.items():
1139
+ if Metric.could_log(v):
1140
+ self.live.log_metric(k, v, **kwargs)
1141
+ else:
1142
+ logger.warning_once(
1143
+ "Accelerator attempted to log a value of "
1144
+ f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
1145
+ "This invocation of DVCLive's Live.log_metric() "
1146
+ "is incorrect so we dropped this attribute."
1147
+ )
1148
+ self.live.next_step()
1149
+
1150
+ @on_main_process
1151
+ def finish(self):
1152
+ """
1153
+ Closes `dvclive.Live()`.
1154
+ """
1155
+ self.live.end()
1156
+
1157
+
1158
+ class SwanLabTracker(GeneralTracker):
1159
+ """
1160
+ A `Tracker` class that supports `swanlab`. Should be initialized at the start of your script.
1161
+
1162
+ Args:
1163
+ run_name (`str`):
1164
+ The name of the experiment run.
1165
+ **kwargs (additional keyword arguments, *optional*):
1166
+ Additional key word arguments passed along to the `swanlab.init` method.
1167
+ """
1168
+
1169
+ name = "swanlab"
1170
+ requires_logging_directory = False
1171
+ main_process_only = False
1172
+
1173
+ def __init__(self, run_name: str, **kwargs):
1174
+ super().__init__()
1175
+ self.run_name = run_name
1176
+ self.init_kwargs = kwargs
1177
+
1178
+ @on_main_process
1179
+ def start(self):
1180
+ import swanlab
1181
+
1182
+ self.run = swanlab.init(project=self.run_name, **self.init_kwargs)
1183
+ swanlab.config["FRAMEWORK"] = "🤗Accelerate" # add accelerate logo in config
1184
+ logger.debug(f"Initialized SwanLab project {self.run_name}")
1185
+ logger.debug(
1186
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
1187
+ )
1188
+
1189
+ @property
1190
+ def tracker(self):
1191
+ return self.run
1192
+
1193
+ @on_main_process
1194
+ def store_init_configuration(self, values: dict):
1195
+ """
1196
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
1197
+
1198
+ Args:
1199
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
1200
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
1201
+ `str`, `float`, `int`, or `None`.
1202
+ """
1203
+ import swanlab
1204
+
1205
+ swanlab.config.update(values, allow_val_change=True)
1206
+ logger.debug("Stored initial configuration hyperparameters to SwanLab")
1207
+
1208
+ @on_main_process
1209
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
1210
+ """
1211
+ Logs `values` to the current run.
1212
+
1213
+ Args:
1214
+ data : Dict[str, DataType]
1215
+ Data must be a dict. The key must be a string with 0-9, a-z, A-Z, " ", "_", "-", "/". The value must be a
1216
+ `float`, `float convertible object`, `int` or `swanlab.data.BaseType`.
1217
+ step : int, optional
1218
+ The step number of the current data, if not provided, it will be automatically incremented.
1219
+ If step is duplicated, the data will be ignored.
1220
+ kwargs:
1221
+ Additional key word arguments passed along to the `swanlab.log` method. Likes:
1222
+ print_to_console : bool, optional
1223
+ Whether to print the data to the console, the default is False.
1224
+ """
1225
+ self.run.log(values, step=step, **kwargs)
1226
+ logger.debug("Successfully logged to SwanLab")
1227
+
1228
+ @on_main_process
1229
+ def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
1230
+ """
1231
+ Logs `images` to the current run.
1232
+
1233
+ Args:
1234
+ values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
1235
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
1236
+ step (`int`, *optional*):
1237
+ The run step. If included, the log will be affiliated with this step.
1238
+ kwargs:
1239
+ Additional key word arguments passed along to the `swanlab.log` method. Likes:
1240
+ print_to_console : bool, optional
1241
+ Whether to print the data to the console, the default is False.
1242
+ """
1243
+ import swanlab
1244
+
1245
+ for k, v in values.items():
1246
+ self.log({k: [swanlab.Image(image) for image in v]}, step=step, **kwargs)
1247
+ logger.debug("Successfully logged images to SwanLab")
1248
+
1249
+ @on_main_process
1250
+ def finish(self):
1251
+ """
1252
+ Closes `swanlab` writer
1253
+ """
1254
+ self.run.finish()
1255
+ logger.debug("SwanLab run closed")
1256
+
1257
+
1258
+ LOGGER_TYPE_TO_CLASS = {
1259
+ "aim": AimTracker,
1260
+ "comet_ml": CometMLTracker,
1261
+ "mlflow": MLflowTracker,
1262
+ "tensorboard": TensorBoardTracker,
1263
+ "wandb": WandBTracker,
1264
+ "clearml": ClearMLTracker,
1265
+ "dvclive": DVCLiveTracker,
1266
+ "swanlab": SwanLabTracker,
1267
+ "trackio": TrackioTracker,
1268
+ }
1269
+
1270
+
1271
+ def filter_trackers(
1272
+ log_with: list[Union[str, LoggerType, GeneralTracker]],
1273
+ logging_dir: Optional[Union[str, os.PathLike]] = None,
1274
+ ):
1275
+ """
1276
+ Takes in a list of potential tracker types and checks that:
1277
+ - The tracker wanted is available in that environment
1278
+ - Filters out repeats of tracker types
1279
+ - If `all` is in `log_with`, will return all trackers in the environment
1280
+ - If a tracker requires a `logging_dir`, ensures that `logging_dir` is not `None`
1281
+
1282
+ Args:
1283
+ log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*):
1284
+ A list of loggers to be setup for experiment tracking. Should be one or several of:
1285
+
1286
+ - `"all"`
1287
+ - `"tensorboard"`
1288
+ - `"wandb"`
1289
+ - `"trackio"`
1290
+ - `"aim"`
1291
+ - `"comet_ml"`
1292
+ - `"mlflow"`
1293
+ - `"dvclive"`
1294
+ - `"swanlab"`
1295
+ If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
1296
+ also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
1297
+ logging_dir (`str`, `os.PathLike`, *optional*):
1298
+ A path to a directory for storing logs of locally-compatible loggers.
1299
+ """
1300
+ loggers = []
1301
+ if log_with is not None:
1302
+ if not isinstance(log_with, (list, tuple)):
1303
+ log_with = [log_with]
1304
+ if "all" in log_with or LoggerType.ALL in log_with:
1305
+ loggers = [o for o in log_with if issubclass(type(o), GeneralTracker)] + get_available_trackers()
1306
+ else:
1307
+ for log_type in log_with:
1308
+ if log_type not in LoggerType and not issubclass(type(log_type), GeneralTracker):
1309
+ raise ValueError(f"Unsupported logging capability: {log_type}. Choose between {LoggerType.list()}")
1310
+ if issubclass(type(log_type), GeneralTracker):
1311
+ loggers.append(log_type)
1312
+ else:
1313
+ log_type = LoggerType(log_type)
1314
+ if log_type not in loggers:
1315
+ if log_type in get_available_trackers():
1316
+ tracker_init = LOGGER_TYPE_TO_CLASS[str(log_type)]
1317
+ if tracker_init.requires_logging_directory:
1318
+ if logging_dir is None:
1319
+ raise ValueError(
1320
+ f"Logging with `{log_type}` requires a `logging_dir` to be passed in."
1321
+ )
1322
+ loggers.append(log_type)
1323
+ else:
1324
+ logger.debug(f"Tried adding logger {log_type}, but package is unavailable in the system.")
1325
+
1326
+ return loggers