koichi12 committed
Commit d5967d1 (verified)
1 parent: e549173

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .venv/lib/python3.11/site-packages/ray/train/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/ray/train/__pycache__/backend.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/train/__pycache__/base_trainer.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/train/__pycache__/constants.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/train/__pycache__/context.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/train/__pycache__/data_parallel_trainer.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/train/__pycache__/error.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/train/__pycache__/predictor.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/train/__pycache__/session.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/train/__pycache__/utils.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/train/_internal/__init__.py +0 -0
  12. .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/__init__.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/accelerator.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/backend_executor.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/checkpoint_manager.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/data_config.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/dl_predictor.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/framework_checkpoint.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/session.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/storage.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/syncer.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/utils.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/worker_group.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/train/_internal/accelerator.py +5 -0
  25. .venv/lib/python3.11/site-packages/ray/train/_internal/backend_executor.py +830 -0
  26. .venv/lib/python3.11/site-packages/ray/train/_internal/checkpoint_manager.py +185 -0
  27. .venv/lib/python3.11/site-packages/ray/train/_internal/data_config.py +139 -0
  28. .venv/lib/python3.11/site-packages/ray/train/_internal/dl_predictor.py +103 -0
  29. .venv/lib/python3.11/site-packages/ray/train/_internal/framework_checkpoint.py +45 -0
  30. .venv/lib/python3.11/site-packages/ray/train/_internal/session.py +1163 -0
  31. .venv/lib/python3.11/site-packages/ray/train/_internal/state/__init__.py +14 -0
  32. .venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/__init__.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/schema.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/state_actor.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/state_manager.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/ray/train/_internal/state/schema.py +158 -0
  37. .venv/lib/python3.11/site-packages/ray/train/_internal/state/state_actor.py +62 -0
  38. .venv/lib/python3.11/site-packages/ray/train/_internal/state/state_manager.py +126 -0
  39. .venv/lib/python3.11/site-packages/ray/train/_internal/storage.py +725 -0
  40. .venv/lib/python3.11/site-packages/ray/train/_internal/syncer.py +490 -0
  41. .venv/lib/python3.11/site-packages/ray/train/_internal/utils.py +239 -0
  42. .venv/lib/python3.11/site-packages/ray/train/_internal/worker_group.py +426 -0
  43. .venv/lib/python3.11/site-packages/ray/train/horovod/__init__.py +22 -0
  44. .venv/lib/python3.11/site-packages/ray/train/horovod/__pycache__/__init__.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/ray/train/horovod/__pycache__/config.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/ray/train/horovod/__pycache__/horovod_trainer.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/ray/train/horovod/config.py +159 -0
  48. .venv/lib/python3.11/site-packages/ray/train/horovod/horovod_trainer.py +202 -0
  49. .venv/lib/python3.11/site-packages/ray/train/lightning/__init__.py +39 -0
  50. .venv/lib/python3.11/site-packages/ray/train/lightning/__pycache__/__init__.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/train/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.69 kB).
.venv/lib/python3.11/site-packages/ray/train/__pycache__/backend.cpython-311.pyc ADDED
Binary file (3.3 kB).
.venv/lib/python3.11/site-packages/ray/train/__pycache__/base_trainer.cpython-311.pyc ADDED
Binary file (37.6 kB).
.venv/lib/python3.11/site-packages/ray/train/__pycache__/constants.cpython-311.pyc ADDED
Binary file (3.1 kB).
.venv/lib/python3.11/site-packages/ray/train/__pycache__/context.cpython-311.pyc ADDED
Binary file (7.62 kB).
.venv/lib/python3.11/site-packages/ray/train/__pycache__/data_parallel_trainer.cpython-311.pyc ADDED
Binary file (26.8 kB).
.venv/lib/python3.11/site-packages/ray/train/__pycache__/error.cpython-311.pyc ADDED
Binary file (671 Bytes).
.venv/lib/python3.11/site-packages/ray/train/__pycache__/predictor.cpython-311.pyc ADDED
Binary file (12.6 kB).
.venv/lib/python3.11/site-packages/ray/train/__pycache__/session.cpython-311.pyc ADDED
Binary file (181 Bytes).
.venv/lib/python3.11/site-packages/ray/train/__pycache__/utils.cpython-311.pyc ADDED
Binary file (896 Bytes).
.venv/lib/python3.11/site-packages/ray/train/_internal/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (192 Bytes).
.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/accelerator.cpython-311.pyc ADDED
Binary file (551 Bytes).
.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/backend_executor.cpython-311.pyc ADDED
Binary file (36.6 kB).
.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/checkpoint_manager.cpython-311.pyc ADDED
Binary file (8.59 kB).
.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/data_config.cpython-311.pyc ADDED
Binary file (6.97 kB).
.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/dl_predictor.cpython-311.pyc ADDED
Binary file (5.64 kB).
.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/framework_checkpoint.cpython-311.pyc ADDED
Binary file (2.52 kB).
.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/session.cpython-311.pyc ADDED
Binary file (46.4 kB).
.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/storage.cpython-311.pyc ADDED
Binary file (36.2 kB).
.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/syncer.cpython-311.pyc ADDED
Binary file (23.8 kB).
.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/utils.cpython-311.pyc ADDED
Binary file (11.5 kB).
.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/worker_group.cpython-311.pyc ADDED
Binary file (21.4 kB).
.venv/lib/python3.11/site-packages/ray/train/_internal/accelerator.py ADDED
@@ -0,0 +1,5 @@
import abc


class Accelerator(abc.ABC):
    """A utility that contains methods to accelerate training."""
.venv/lib/python3.11/site-packages/ray/train/_internal/backend_executor.py ADDED
@@ -0,0 +1,830 @@
import logging
import os
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar

import ray
import ray._private.ray_constants as ray_constants
from ray._private.ray_constants import env_integer
from ray.data import Dataset
from ray.exceptions import RayActorError
from ray.train import Checkpoint, DataConfig
from ray.train._internal.session import (
    TrialInfo,
    _TrainingResult,
    get_session,
    init_session,
    shutdown_session,
)
from ray.train._internal.storage import StorageContext
from ray.train._internal.utils import check_for_failure
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import BackendConfig
from ray.train.constants import (
    ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
    ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
    ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
    ENABLE_SHARE_NPU_RT_VISIBLE_DEVICES_ENV,
    ENABLE_SHARE_ROCR_VISIBLE_DEVICES_ENV,
    RAY_TRAIN_ENABLE_STATE_TRACKING,
    TRAIN_ENABLE_WORKER_SPREAD_ENV,
    TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
)
from ray.util.placement_group import get_current_placement_group, remove_placement_group

T = TypeVar("T")

logger = logging.getLogger(__name__)


class TrainBackendError(Exception):
    """Errors with BackendExecutor that should not be exposed to user."""


class TrainingWorkerError(Exception):
    """Raised if a worker fails during training."""


@dataclass
class ResourceConfig:
    """
    Resource configuration for resource_ids to share between workers.

    Args:
        resource_name: The name of the resource to configure
            (Example: "neuron_cores" or "gpu").
        resource_enable_sharing_env_var: The environment variable to
            check if the resource should be shared.
        share_resource_ids_env_var: The environment variable to configure for
            sharing the resources with other workers.
    """

    resource_name: str
    resource_enable_sharing_env_var: str
    share_resource_ids_env_var: str


class BackendExecutor:
    """Main execution class for training backends.

    This class holds a worker group and is responsible for executing the
    training function on the workers, and collecting intermediate results
    from ``session.report()``.

    Args:
        backend_config: The configurations for this
            specific backend.
        num_workers: Number of workers to use for training.
        resources_per_worker (Optional[Dict[str, float]]):
            Dictionary specifying the resources that will be
            requested for each worker. Defaults to {"CPU": 1}.
        max_retries: Number of retries when Ray actors fail.
            Defaults to 3. Set to -1 for unlimited retries.
    """

    def __init__(
        self,
        backend_config: BackendConfig,
        # TODO(xwjiang): Legacy Ray Train trainer clean up!
        trial_info: Optional[TrialInfo] = None,
        num_workers: int = 1,
        resources_per_worker: Optional[Dict[str, float]] = None,
        max_retries: int = 3,
    ):
        if resources_per_worker is None:
            self._resources_per_worker = {"CPU": 1}
        else:
            self._resources_per_worker = resources_per_worker.copy()

        self._backend_config = backend_config
        self._backend = backend_config.backend_cls()
        self._num_workers = num_workers
        self._max_failures = max_retries
        if self._max_failures < 0:
            self._max_failures = float("inf")
        self._num_failures = 0
        self._last_failure = None
        self._initialization_hook = None
        self._placement_group = None

        self._trial_info = trial_info

        self.worker_group = InactiveWorkerGroup()
        self.dataset_shards = None

        self._resource_configs = [
            ResourceConfig(
                ray_constants.NEURON_CORES,
                ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
                ray_constants.NEURON_RT_VISIBLE_CORES_ENV_VAR,
            ),
            ResourceConfig(
                ray_constants.NPU,
                ENABLE_SHARE_NPU_RT_VISIBLE_DEVICES_ENV,
                ray_constants.NPU_RT_VISIBLE_DEVICES_ENV_VAR,
            ),
            # For AMD GPUs, they are using ROCR_VISIBLE_DEVICES env var.
            ResourceConfig(
                ray_constants.GPU,
                ENABLE_SHARE_ROCR_VISIBLE_DEVICES_ENV,
                ray_constants.ROCR_VISIBLE_DEVICES_ENV_VAR,
            ),
        ]

        # Record the initialization time of BackendExecutor, which is
        # after trainer.fit() and before worker_group executes the training function.
        self._start_time_ms = int(time.time() * 1000)

        self.state_tracking_enabled = env_integer(RAY_TRAIN_ENABLE_STATE_TRACKING, 0)

    def start(
        self,
        initialization_hook: Optional[Callable[[], None]] = None,
        train_cls: Optional[Type] = None,
        train_cls_args: Optional[Tuple] = None,
        train_cls_kwargs: Optional[Dict] = None,
    ):
        """Starts the worker group."""
        self._create_placement_group()
        placement_group = self._placement_group or "default"
        self.worker_group = WorkerGroup(
            num_workers=self._num_workers,
            resources_per_worker=self._resources_per_worker,
            actor_cls=train_cls,
            actor_cls_args=train_cls_args,
            actor_cls_kwargs=train_cls_kwargs,
            placement_group=placement_group,
        )
        # Hack to avoid OOMs.
        # This is just a temporary solution for Train loading entire checkpoints
        # into memory by ensuring that the rank 0 worker is on the same node as
        # trainable, thus allowing for lazy checkpoint transfer to be used.
        # See https://github.com/ray-project/ray/issues/33073
        # for more context.
        # TODO remove passing in trial_driver_ip.

        trial_driver_node_id = (
            self._trial_info.driver_node_id if self._trial_info else None
        )
        self.worker_group.sort_workers_by_node_id_and_gpu_id(trial_driver_node_id)

        try:
            if initialization_hook:
                self._initialization_hook = initialization_hook
                self.worker_group.execute(initialization_hook)

            # Always propagate the driver's DataContext to each worker in the group.
            from ray.data import DataContext

            def _set_driver_dataset_context(ctx: DataContext):
                DataContext._set_current(ctx)

            self.worker_group.execute(
                _set_driver_dataset_context,
                DataContext.get_current(),
            )

            share_cuda_visible_devices_enabled = bool(
                env_integer(
                    ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
                    self._backend.share_cuda_visible_devices,
                )
            )

            if (
                self._resources_per_worker.get("GPU", 0) > 0
                and share_cuda_visible_devices_enabled
            ):
                self._share_cuda_visible_devices()
            for resource_config in self._resource_configs:
                if self._is_share_resources_enabled(
                    resource_config.resource_name,
                    resource_config.resource_enable_sharing_env_var,
                ):
                    self._share_resource_ids(
                        resource_config.resource_name,
                        resource_config.share_resource_ids_env_var,
                    )
            self._backend.on_start(self.worker_group, self._backend_config)
        except RayActorError as exc:
            logger.exception(str(exc))
            logger.warning(
                "Failure occurred during startup. Restarting all workers and "
                "attempting to startup again."
            )
            self._increment_failures()
            self._restart()

        if self.state_tracking_enabled:
            from ray.train._internal.state import TrainRunStateManager
            from ray.train._internal.state.state_actor import get_state_actor

            self.state_manager = TrainRunStateManager(state_actor=get_state_actor())

    def _create_placement_group(self):
        """Creates a placement group if it does not exist.

        If a placement group is already detected (Tune) this will be a no-op.

        By default the placement group will be created with PACK strategy.
        This is optimized for colocating GPUs on a minimal number of nodes.
        This behavior can be overridden to use the SPREAD strategy by defining
        ``TRAIN_ENABLE_WORKER_SPREAD_ENV``

        If a placement group is created it will be stored as
        self._placement_group.
        """
        current_placement_group = get_current_placement_group()
        worker = ray._private.worker.global_worker
        should_capture_child_tasks_in_placement_group = (
            worker.should_capture_child_tasks_in_placement_group
        )
        should_create_placement_group = (
            current_placement_group is None
            or not should_capture_child_tasks_in_placement_group
        )

        if should_create_placement_group:
            bundles = [
                self._resources_per_worker.copy() for _ in range(self._num_workers)
            ]

            use_spread = bool(env_integer(TRAIN_ENABLE_WORKER_SPREAD_ENV, 0))
            strategy = "SPREAD" if use_spread else "PACK"

            placement_group = ray.util.placement_group(bundles, strategy=strategy)
            logger.debug("Waiting for placement group to start.")
            timeout = env_integer(TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV, 100)
            ready, _ = ray.wait([placement_group.ready()], timeout=timeout)
            if ready:
                logger.debug("Placement group has started.")
            else:
                raise TimeoutError(
                    "Placement group creation timed out. Make sure your "
                    "cluster either has enough resources or use an "
                    "autoscaling cluster. If you are running on a cluster, "
                    "make sure you specify an address in `ray.init()`, for example, "
                    '`ray.init("auto")`. You can also increase the timeout by setting '
                    "the TRAIN_PLACEMENT_GROUP_TIMEOUT_S environment variable. "
                    "Current resources available: {}, resources requested by the "
                    "placement group: {}".format(
                        ray.available_resources(), placement_group.bundle_specs
                    )
                )
            self._placement_group = placement_group

    def _share_cuda_visible_devices(self):
        """Sets CUDA_VISIBLE_DEVICES on all workers.

        For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs
        visible to all workers on that worker's node.

        This allows GPU workers on the same node to communicate with one
        another.

        Example:

            Setup:
            - Node1:
                - Worker1: {0, 1}
                - Worker2: {2, 3}
            - Node2:
                - Worker3: {0, 1}

            CUDA_VISIBLE_DEVICES:
            - Worker1: "0,1,2,3"
            - Worker2: "0,1,2,3"
            - Worker3: "0,1"

        """
        self._share_resource_ids(
            ray_constants.GPU, ray_constants.CUDA_VISIBLE_DEVICES_ENV_VAR
        )

    def _share_resource_ids(self, resource: str, env_var: str):
        """Sets the given env_var on all workers.

        For each worker, the cores/devices are visible to all the
        workers on that worker's node. This allows workers on the
        same node to communicate with one another.

        Example:

            Setup:
            - Node1:
                - Worker1: {0, 1}
                - Worker2: {2, 3}
            - Node2:
                - Worker3: {0, 1}

            NEURON_RT_VISIBLE_CORES/TPU_VISIBLE_CHIPS/...:
            - Worker1: "0,1,2,3"
            - Worker2: "0,1,2,3"
            - Worker3: "0,1"

        Args:
            resource: The name of the resource/accelerator.
            env_var: The name of the environment variable to set.
        """
        node_ids_and_resource_ids = [
            (
                w.metadata.node_id,
                w.metadata.resource_ids[resource],
            )
            for w in self.worker_group.workers
        ]
        node_id_to_worker_id = defaultdict(set)
        node_id_to_resource_ids = defaultdict(set)

        for worker_id, (node_id, resource_ids) in enumerate(node_ids_and_resource_ids):
            node_id_to_worker_id[node_id].add(worker_id)
            node_id_to_resource_ids[node_id].update(resource_ids)

        futures = []
        for node_id, resource_ids in node_id_to_resource_ids.items():
            resource_ids = sorted(resource_ids)
            all_resource_ids = ",".join(resource_ids)

            def set_resource_ids():
                os.environ[env_var] = all_resource_ids

            for worker_id in node_id_to_worker_id[node_id]:
                futures.append(
                    self.worker_group.execute_single_async(worker_id, set_resource_ids)
                )
        ray.get(futures)

    def _is_share_resources_enabled(self, resource_name: str, enable_sharing_env: str):
        """Whether to share resource IDs on all workers
        based on enable_sharing_env.

        This will return true if resources are requested and greater than 0.
        Also, user can disable by configuring the `enable_sharing_env` to "0".

        Args:
            resource_name: The name of the resource/accelerator.
            enable_sharing_env: The name of the environment variable
                to check.
        """
        has_resource_requested = self._resources_per_worker.get(resource_name, 0) > 0
        return has_resource_requested and ray_constants.env_bool(
            enable_sharing_env, True
        )

    def _create_rank_world_size_mappings(self) -> List[Dict]:
        """Create rank and world size mappings for workers.
        There are three maps returned:
            - local_rank_map, which maps from worker world_rank to local_rank.
            - local_world_size_map, which maps from world_rank to local_world_size
            - node_rank_map, which maps from world rank to node rank

        Example:
            Worker 0: node 0
            Worker 1: node 0
            Worker 2: node 1
            Worker 3: node 0
            Worker 4: node 1

            Workers 0, 1, 3 are on node 0.
            Workers 2, 4 are on node 1.

            Expected local_rank_map:
            {
                0 -> 0,
                1 -> 1,
                2 -> 0,
                3 -> 2,
                4 -> 1
            }

            Expected local_world_size_map:
            {
                0 -> 3,
                1 -> 3,
                2 -> 2,
                3 -> 3,
                4 -> 2
            }

            Expected node_rank_map:
            {
                0 -> 0,
                1 -> 0,
                2 -> 1,
                3 -> 0,
                4 -> 1
            }

        """
        local_rank_map = {}  # map from world rank to local rank
        local_world_size_map = {}  # map from world rank to local world size
        node_rank_map = {}  # map from world rank to node rank
        node_ids = {}  # map from node id to node index
        node_cnt = 0  # count the number of nodes

        node_id_dict = defaultdict(
            int
        )  # map from node id to the number of workers on it.
        for world_rank in range(len(self.worker_group)):
            worker = self.worker_group.workers[world_rank]
            node_id = worker.metadata.node_id
            local_rank_map[world_rank] = node_id_dict[node_id]
            node_id_dict[node_id] += 1

            if node_id not in node_ids:
                node_ids[node_id] = node_cnt
                node_cnt += 1
            node_rank_map[world_rank] = node_ids[node_id]

        for world_rank in range(len(self.worker_group)):
            worker = self.worker_group.workers[world_rank]
            node_id = worker.metadata.node_id
            local_world_size_map[world_rank] = node_id_dict[node_id]

        workers_info = "\n".join(
            [
                f"- (node_id={w.metadata.node_id}, ip={w.metadata.node_ip}, "
                f"pid={w.metadata.pid}) world_rank={i}, "
                f"local_rank={local_rank_map[i]}, node_rank={node_rank_map[i]}"
                for i, w in enumerate(self.worker_group.workers)
            ]
        )
        logger.info(f"Started distributed worker processes: \n{workers_info}")

        return local_rank_map, local_world_size_map, node_rank_map

    def start_training(
        self,
        train_func: Callable[[], T],
        datasets: Dict[str, Dataset],
        metadata: Dict[str, Any],
        data_config: DataConfig,
        storage: StorageContext,
        checkpoint: Optional[Checkpoint] = None,
    ) -> None:
        """Executes a training function on all workers in a separate thread.

        ``finish_training`` should be called after this.

        Args:
            train_func: The training function to run on each worker.
            datasets: The base datasets.
            data_config: The config object for creating dataset shards for workers.
            checkpoint: The checkpoint data that
                should be loaded onto each worker and accessed by the
                training function via ``session.get_checkpoint()``. If this
                is ``None`` then no checkpoint will be loaded.
        """
        use_detailed_autofilled_metrics = env_integer(
            ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, 0
        )

        # First initialize the session.
        def initialize_session(
            train_func,
            world_rank,
            local_rank,
            node_rank,
            local_world_size,
            world_size,
            trial_info,
            checkpoint,
            dataset_shard,
            metadata,
            storage,
        ):
            try:
                init_session(
                    training_func=train_func,
                    world_rank=world_rank,
                    local_rank=local_rank,
                    node_rank=node_rank,
                    local_world_size=local_world_size,
                    world_size=world_size,
                    trial_info=trial_info,
                    dataset_shard=dataset_shard,
                    metadata=metadata,
                    checkpoint=checkpoint,
                    detailed_autofilled_metrics=use_detailed_autofilled_metrics,
                    storage=storage,
                )
            except ValueError:
                raise TrainBackendError(
                    "Attempting to start training but a "
                    "previous training run is still ongoing. "
                    "You must call `finish_training` before "
                    "calling `start_training` again."
                )

        if self.dataset_shards is None:
            actors = [worker.actor for worker in self.worker_group.workers]
            node_ids = [worker.metadata.node_id for worker in self.worker_group.workers]
            self.dataset_shards = data_config.configure(
                datasets,
                world_size=len(self.worker_group),
                worker_handles=actors,
                worker_node_ids=node_ids,
            )

        (
            local_rank_map,
            local_world_size_map,
            node_rank_map,
        ) = self._create_rank_world_size_mappings()

        futures = []
        for index in range(len(self.worker_group)):
            futures.append(
                self.worker_group.execute_single_async(
                    index,
                    initialize_session,
                    world_rank=index,
                    local_rank=local_rank_map[index],
                    node_rank=node_rank_map[index],
                    local_world_size=local_world_size_map[index],
                    world_size=len(self.worker_group),
                    trial_info=self._trial_info,
                    train_func=train_func,
                    dataset_shard=self.dataset_shards[index],
                    metadata=metadata,
                    checkpoint=checkpoint,
                    storage=storage,
                )
            )

        self._backend.on_training_start(self.worker_group, self._backend_config)

        self.get_with_failure_handling(futures)

        # Register Train Run before training starts
        if self.state_tracking_enabled:
            from ray.train._internal.state.schema import RunStatusEnum

            core_context = ray.runtime_context.get_runtime_context()

            self.state_manager.register_train_run(
                run_id=self._trial_info.run_id,
                run_name=self._trial_info.experiment_name,
                job_id=core_context.get_job_id(),
                controller_actor_id=core_context.get_actor_id(),
                datasets=datasets,
                worker_group=self.worker_group,
                start_time_ms=self._start_time_ms,
                run_status=RunStatusEnum.RUNNING,
            )

        # Run the training function asynchronously in its own thread.
        def train_async():
            session = get_session()
            session.start()

        self.worker_group.execute_async(train_async)

    def get_next_results(self) -> Optional[List[_TrainingResult]]:
        """Fetches the next ``_TrainingResult`` from each worker.

        Each ``_TrainingResult`` is expected to correspond to the same step from
        each worker (e.g. the same call to ``train.report()``).

        Returns:
            A list of ``_TrainingResult``s or ``None`` if there are no more results
            since the training function has exited on all workers.
        """

        def get_next():
            session = _get_session("get_next_results")
            try:
                result = session.get_next()
            except RuntimeError:
                # Training thread has not been started yet.
                raise TrainBackendError(
                    "`get_next_results` has been called "
                    "before `start_training`. Please call "
                    "`start_training` before "
                    "`get_next_results`."
                )

            return result

        # Get next result from each worker.
        futures = self.worker_group.execute_async(get_next)
        results = self.get_with_failure_handling(futures)

        # Check if any worker returned None.
        if any(r is None for r in results):
            # Either all workers have results or none of them do.
            if not all(r is None for r in results):
                raise RuntimeError(
                    "Some workers returned results while "
                    "others didn't. Make sure that "
                    "`session.report()` are called the "
                    "same number of times on all workers."
                )
            else:
                # Return None if all results are None.
                return None

        return results

    def pause_reporting(self):
        """Disable workers from enqueuing results from ``session.report()``.

        Note: Already reported results may still be enqueued at this point,
        and should be handled appropriately.
        """

        def pause_session_reporting():
            session = _get_session("pause_reporting")
            return session.pause_reporting()

        futures = self.worker_group.execute_async(pause_session_reporting)
        self.get_with_failure_handling(futures)

    def finish_training(self):
        """Finish training and return final results. Propagate any exceptions.

        Blocks until training is finished on all workers.

        Assumes `start_training` has already been called.

        Returns:
            A list of return values from calling ``train_func`` on each worker.
                Each item corresponds to the return value from a single worker.
        """

        def end_training():
            session = _get_session("finish_training")
            try:
                # session.finish raises any Exceptions from training.
                output = session.finish()
            finally:
                # Shutdown session even if session.finish() raises an
                # Exception.
                shutdown_session()

            return output

        futures = self.worker_group.execute_async(end_training)
        results = self.get_with_failure_handling(futures)
        return results

    def report_final_run_status(
        self,
        errored: bool = False,
        failed_rank: Optional[int] = None,
        stack_trace: Optional[str] = None,
    ):
        """Report the final train run status, error, and end time to TrainStateActor."""
        if self.state_tracking_enabled:
            from ray.train._internal.state.schema import (
                MAX_ERROR_STACK_TRACE_LENGTH,
                RunStatusEnum,
            )

            if errored:
                run_status = RunStatusEnum.ERRORED
                status_detail = ""
                if failed_rank is not None:
                    status_detail += f"Rank {failed_rank} worker raised an error. \n"
                if stack_trace is not None:
                    # Keep only the last part of the stack trace if it's too long.
                    status_detail += stack_trace[-MAX_ERROR_STACK_TRACE_LENGTH:]
            else:
                run_status = RunStatusEnum.FINISHED
                status_detail = ""

            self.state_manager.end_train_run(
                run_id=self._trial_info.run_id,
                run_status=run_status,
                status_detail=status_detail,
                end_time_ms=int(time.time() * 1000),
            )

    def get_with_failure_handling(self, remote_values):
        """Gets the remote values while handling for worker failures.

        This method should be called instead of ``ray.get()`` directly in
        order to handle worker failures.

        If a worker failure is identified, backend specific failure handling
        is executed and a ``TrainingWorkerError`` is raised.

        Args:
            remote_values: List of object refs representing functions
                that may fail in the middle of execution. For example, running
                a Train training loop in multiple parallel actor calls.
        Returns:
            The resolved objects represented by the passed in ObjectRefs.
        """
        success, exception = check_for_failure(remote_values)
        if success:
            return ray.get(remote_values)
        else:
            self._last_failure = exception
            self._increment_failures()
            logger.warning(
                "Failure identified during training. Restarting all workers and "
                "continuing training from latest checkpoint."
            )
            self._restart()
            raise TrainingWorkerError

    def shutdown(self, graceful_termination: bool = True):
        """Shuts down the workers in the worker group.

        Args:
            graceful_termination: If set to True, attempt to clean up the backend
                before terminating the Ray actors.

        """
        if graceful_termination:
            try:
                self._backend.on_shutdown(self.worker_group, self._backend_config)
            except RayActorError:
                logger.warning(
                    "Graceful shutdown of backend failed. This is "
                    "expected if one of the workers has crashed."
                )

        if graceful_termination:
            self.worker_group.shutdown()
        else:
            self.worker_group.shutdown(patience_s=0)
        self.worker_group = InactiveWorkerGroup()

        if self._placement_group:
            remove_placement_group(self._placement_group)
            self._placement_group = None

        self.dataset_shards = None

    def is_started(self):
        return not isinstance(self.worker_group, InactiveWorkerGroup)

    def _restart(self):
        self.worker_group.shutdown()
        if self._initialization_hook is not None:
            initialization_hook = self._initialization_hook
        else:
            initialization_hook = None
        if self._placement_group:
            remove_placement_group(self._placement_group)
            self._placement_group = None
        self.start(initialization_hook=initialization_hook)

    def _increment_failures(self):
        self._num_failures += 1
        if self._num_failures >= self._max_failures:
            failure = self._last_failure
            self._last_failure = None
            if self._max_failures > 0:
                exc = RuntimeError(
                    "Training has failed after " f"{self._num_failures} " "attempts."
                )
                raise exc.with_traceback(None) from failure
            else:
                raise failure

    def get_worker_group(self):
        return self.worker_group

    def _get_num_failures(self):
        return self._num_failures


class InactiveWorkerGroupError(Exception):
    """Raised when underlying worker group is inactive."""


class InactiveWorkerGroup:
    # TODO: fix inheritence. perhaps create WorkerGroupInterface.

    # Need to define getstate and setstate so that getattr does not screwup
    # pickling. See https://stackoverflow.com/a/50888571/11249691
    def __getstate__(self):
        return vars(self)

    def __setstate__(self, state):
        vars(self).update(state)

    def __getattr__(self, name):
        raise InactiveWorkerGroupError()

    def __len__(self):
        raise InactiveWorkerGroupError()


def _get_session(method_name: str):
    # Get the session for this worker.
    session = get_session()
    if not session:
        # Session is not initialized yet.
        raise TrainBackendError(
            f"`{method_name}` has been called "
            "before `start_training`. Please call "
            "`start_training` before "
            f"`{method_name}`."
        )
    return session
.venv/lib/python3.11/site-packages/ray/train/_internal/checkpoint_manager.py ADDED
@@ -0,0 +1,185 @@
import logging
import numbers
from typing import Any, Callable, List, Optional, Tuple

from ray._private.dict import flatten_dict
from ray.air._internal.util import is_nan
from ray.air.config import MAX
from ray.train import CheckpointConfig
from ray.train._internal.session import _TrainingResult
from ray.train._internal.storage import _delete_fs_path

logger = logging.getLogger(__name__)


def _insert_into_sorted_list(list: List[Any], item: Any, key: Callable[[Any], Any]):
    """Insert an item into a sorted list with a custom key function.

    Examples:

        >>> list = []
        >>> _insert_into_sorted_list(list, {"a": 1, "b": 0}, lambda x: x["a"])
        >>> list
        [{'a': 1, 'b': 0}]
        >>> _insert_into_sorted_list(list, {"a": 3, "b": 1}, lambda x: x["a"])
        >>> list
        [{'a': 1, 'b': 0}, {'a': 3, 'b': 1}]
        >>> _insert_into_sorted_list(list, {"a": 4, "b": 2}, lambda x: x["a"])
        >>> list
        [{'a': 1, 'b': 0}, {'a': 3, 'b': 1}, {'a': 4, 'b': 2}]
        >>> _insert_into_sorted_list(list, {"a": 1, "b": 3}, lambda x: x["a"])
        >>> list
        [{'a': 1, 'b': 0}, {'a': 1, 'b': 3}, {'a': 3, 'b': 1}, {'a': 4, 'b': 2}]
    """
    i = 0
    while i < len(list):
        # Insert to the right of all duplicates.
        if key(list[i]) > key(item):
            break
        i += 1
    list.insert(i, item)


class _CheckpointManager:
    """Checkpoint manager that handles checkpoint book-keeping for a trial.

    The main purpose of this abstraction is to keep the top K checkpoints based on
    recency/a user-provided metric.

    NOTE: This class interacts with `_TrainingResult` objects, which are
    (checkpoint, metrics) pairs. This is to order checkpoints by metrics.

    Args:
        checkpoint_config: Defines how many and which checkpoints to keep.
    """

    def __init__(self, checkpoint_config: Optional[CheckpointConfig]):
        self._checkpoint_config = checkpoint_config or CheckpointConfig()

        # List of checkpoints ordered by ascending score.
        self._checkpoint_results: List[_TrainingResult] = []

        # The latest registered checkpoint.
        # This should never be immediately deleted upon registration,
        # even if it's not in the top K checkpoints, based on score.
        self._latest_checkpoint_result: Optional[_TrainingResult] = None

        if (
            self._checkpoint_config.num_to_keep is not None
            and self._checkpoint_config.num_to_keep <= 0
        ):
            raise ValueError(
                f"`num_to_keep` must >= 1, got: "
                f"{self._checkpoint_config.num_to_keep}"
            )

    @property
    def checkpoint_config(self):
        return self._checkpoint_config

    def register_checkpoint(self, checkpoint_result: _TrainingResult):
        """Register new checkpoint and add to bookkeeping.

        This method will register a new checkpoint and add it to the internal
        bookkeeping logic. This means the checkpoint manager will decide if
        this checkpoint should be kept, and if older or worse performing
        checkpoints should be deleted.

        Args:
            checkpoint: Tracked checkpoint object to add to bookkeeping.
        """
        self._latest_checkpoint_result = checkpoint_result

        if self._checkpoint_config.checkpoint_score_attribute is not None:
            # If we're ordering by a score, insert the checkpoint
            # so that the list remains sorted.
            _insert_into_sorted_list(
                self._checkpoint_results,
                checkpoint_result,
                key=self._get_checkpoint_score,
            )
        else:
            # If no metric is provided, just append (ordering by time of registration).
            self._checkpoint_results.append(checkpoint_result)

        if self._checkpoint_config.num_to_keep is not None:
            # Delete the bottom (N - K) checkpoints
            worst_results = set(
                self._checkpoint_results[: -self._checkpoint_config.num_to_keep]
            )
            # Except for the latest checkpoint.
            results_to_delete = worst_results - {self._latest_checkpoint_result}

            # Update internal state before actually deleting them.
            self._checkpoint_results = [
                checkpoint_result
                for checkpoint_result in self._checkpoint_results
                if checkpoint_result not in results_to_delete
            ]

            for checkpoint_result in results_to_delete:
                checkpoint = checkpoint_result.checkpoint
                logger.debug("Deleting checkpoint: ", checkpoint)
                _delete_fs_path(fs=checkpoint.filesystem, fs_path=checkpoint.path)

    def _get_checkpoint_score(
        self, checkpoint: _TrainingResult
    ) -> Tuple[bool, numbers.Number]:
        """Get the score for a checkpoint, according to checkpoint config.

        If `mode="min"`, the metric is negated so that the lowest score is
        treated as the best.

        Returns:
            Tuple: A tuple of (not_is_nan: bool, score: numbers.Number).
            This score orders: nan values < float("-inf") < valid numeric metrics
        """
        checkpoint_score_attribute = self._checkpoint_config.checkpoint_score_attribute
        if checkpoint_score_attribute:
            flat_metrics = flatten_dict(checkpoint.metrics)
            try:
                checkpoint_result = flat_metrics[checkpoint_score_attribute]
            except KeyError:
                valid_keys = list(flat_metrics.keys())
                logger.error(
                    f"Result dict has no key: {checkpoint_score_attribute}. "
                    f"checkpoint_score_attr must be set to a key in the "
                    f"result dict. Valid keys are: {valid_keys}"
                )
                checkpoint_result = float("-inf")
        else:
            checkpoint_result = float("-inf")

        checkpoint_score_order = self._checkpoint_config.checkpoint_score_order
        order_factor = 1.0 if checkpoint_score_order == MAX else -1.0

        checkpoint_score = order_factor * checkpoint_result

        if not isinstance(checkpoint_score, numbers.Number):
            raise ValueError(
                f"Unable to persist checkpoint for "
                f"checkpoint_score_attribute: "
                f"{checkpoint_score_attribute} with value "
                f"{checkpoint_score}. "
                f"This attribute must be numerical."
            )

        return (
            (not is_nan(checkpoint_score), checkpoint_score)
            if not is_nan(checkpoint_score)
            else (False, float("-inf"))
        )

    @property
    def best_checkpoint_result(self) -> Optional[_TrainingResult]:
        return self._checkpoint_results[-1] if self._checkpoint_results else None

    @property
    def latest_checkpoint_result(self) -> Optional[_TrainingResult]:
        return self._latest_checkpoint_result

    @property
    def best_checkpoint_results(self) -> List[_TrainingResult]:
        if self._checkpoint_config.num_to_keep is None:
            return self._checkpoint_results
        return self._checkpoint_results[-self._checkpoint_config.num_to_keep :]
.venv/lib/python3.11/site-packages/ray/train/_internal/data_config.py ADDED
@@ -0,0 +1,139 @@
import copy
from typing import Dict, List, Literal, Optional, Union

import ray
from ray.actor import ActorHandle
from ray.data import DataIterator, Dataset, ExecutionOptions, NodeIdStr
from ray.data._internal.execution.interfaces.execution_options import ExecutionResources
from ray.util.annotations import DeveloperAPI, PublicAPI


@PublicAPI(stability="stable")
class DataConfig:
    """Class responsible for configuring Train dataset preprocessing.

    For advanced use cases, this class can be subclassed and the `configure()` method
    overriden for custom data preprocessing.
    """

    def __init__(
        self,
        datasets_to_split: Union[Literal["all"], List[str]] = "all",
        execution_options: Optional[ExecutionOptions] = None,
    ):
        """Construct a DataConfig.

        Args:
            datasets_to_split: Specifies which datasets should be split among workers.
                Can be set to "all" or a list of dataset names. Defaults to "all",
                i.e. split all datasets.
            execution_options: The execution options to pass to Ray Data. By default,
                the options will be optimized for data ingest. When overriding this,
                base your options off of `DataConfig.default_ingest_options()`.
        """
        if isinstance(datasets_to_split, list) or datasets_to_split == "all":
            self._datasets_to_split = datasets_to_split
        else:
            raise TypeError(
                "`datasets_to_split` should be a 'all' or a list of strings of "
                "dataset names. Received "
                f"{type(datasets_to_split).__name__} with value {datasets_to_split}."
            )

        self._execution_options: ExecutionOptions = (
            execution_options or DataConfig.default_ingest_options()
        )

        self._num_train_cpus = 0.0
        self._num_train_gpus = 0.0

    def set_train_total_resources(self, num_train_cpus: float, num_train_gpus: float):
        """Set the total number of CPUs and GPUs used by training.

        If CPU or GPU resource limits are not set, they will be set to the
        total cluster resources minus the resources used by training.
        """
        # TODO: We may also include other resources besides CPU and GPU.
        self._num_train_cpus = num_train_cpus
        self._num_train_gpus = num_train_gpus

    @DeveloperAPI
    def configure(
        self,
        datasets: Dict[str, Dataset],
        world_size: int,
        worker_handles: Optional[List[ActorHandle]],
        worker_node_ids: Optional[List[NodeIdStr]],
        **kwargs,
    ) -> List[Dict[str, DataIterator]]:
        """Configure how Train datasets should be assigned to workers.

        Args:
            datasets: The datasets dict passed to Train by the user.
            world_size: The number of Train workers in total.
            worker_handles: The actor handles of the Train workers.
            worker_node_ids: The node ids of the Train workers.
            kwargs: Forwards compatibility placeholder.

        Returns:
            A list of dataset splits for each worker. The size of the list must be
            equal to `world_size`. Each element of the list contains the assigned
            `DataIterator` instances by name for the worker.
        """
        output = [{} for _ in range(world_size)]

        if self._datasets_to_split == "all":
            datasets_to_split = set(datasets.keys())
        else:
            datasets_to_split = set(self._datasets_to_split)

        locality_hints = (
            worker_node_ids if self._execution_options.locality_with_output else None
        )
        for name, ds in datasets.items():
            execution_options = copy.deepcopy(self._execution_options)

            if execution_options.is_resource_limits_default():
                # If "resource_limits" is not overriden by the user,
                # add training-reserved resources to Data's exclude_resources.
                execution_options.exclude_resources = (
                    execution_options.exclude_resources.add(
                        ExecutionResources(
                            cpu=self._num_train_cpus, gpu=self._num_train_gpus
                        )
                    )
                )

            ds = ds.copy(ds)
            ds.context.execution_options = execution_options

            if name in datasets_to_split:
                for i, split in enumerate(
                    ds.streaming_split(
                        world_size, equal=True, locality_hints=locality_hints
                    )
                ):
                    output[i][name] = split
            else:
                for i in range(world_size):
                    output[i][name] = ds.iterator()

        return output

    @staticmethod
    def default_ingest_options() -> ExecutionOptions:
        """The default Ray Data options used for data ingest.

        By default, configurations are carried over from what is already set
        in DataContext.
        """
        ctx = ray.data.DataContext.get_current()
        return ExecutionOptions(
            # TODO(hchen): Re-enable `locality_with_output` by default after fixing
            # https://github.com/ray-project/ray/issues/40607
            locality_with_output=ctx.execution_options.locality_with_output,
            resource_limits=ctx.execution_options.resource_limits,
            exclude_resources=ctx.execution_options.exclude_resources,
            preserve_order=ctx.execution_options.preserve_order,
            verbose_progress=ctx.execution_options.verbose_progress,
        )
.venv/lib/python3.11/site-packages/ray/train/_internal/dl_predictor.py ADDED
@@ -0,0 +1,103 @@
import abc
from typing import Dict, Optional, TypeVar, Union

import numpy as np
import pandas as pd

from ray.air.util.data_batch_conversion import (
    BatchFormat,
    _convert_batch_type_to_pandas,
    _convert_pandas_to_batch_type,
)
from ray.train.predictor import Predictor
from ray.util.annotations import DeveloperAPI

TensorType = TypeVar("TensorType")
TensorDtype = TypeVar("TensorDtype")


class DLPredictor(Predictor):
    @abc.abstractmethod
    def _arrays_to_tensors(
        self,
        numpy_arrays: Union[np.ndarray, Dict[str, np.ndarray]],
        dtype: Optional[Union[TensorDtype, Dict[str, TensorDtype]]],
    ) -> Union[TensorType, Dict[str, TensorType]]:
        """Converts a NumPy ndarray batch to the tensor type for the DL framework.

        Args:
            numpy_array: The numpy array to convert to a tensor.
            dtype: The tensor dtype to use when creating the DL tensor.
            ndarray: A (dict of) NumPy ndarray(s) that we wish to convert to a (dict of)
                tensor(s).
            dtype: A (dict of) tensor dtype(s) to use when creating the DL tensor; if
                None, the dtype will be inferred from the NumPy ndarray data.

        Returns:
            A deep learning framework specific tensor.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _tensor_to_array(self, tensor: TensorType) -> np.ndarray:
        """Converts tensor framework specific tensor to a numpy array.

        Args:
            tensor: A framework specific tensor.

        Returns:
            A numpy array representing the input tensor.
        """

        raise NotImplementedError

    @abc.abstractmethod
    @DeveloperAPI
    def call_model(
        self, inputs: Union[TensorType, Dict[str, TensorType]]
    ) -> Union[TensorType, Dict[str, TensorType]]:
        """Inputs the tensor to the model for this Predictor and returns the result.

        Args:
            inputs: The tensor to input to the model.

        Returns:
            A tensor or dictionary of tensors containing the model output.
        """
        raise NotImplementedError

    @classmethod
    @DeveloperAPI
    def preferred_batch_format(cls) -> BatchFormat:
        return BatchFormat.NUMPY

    def _predict_pandas(
        self,
        data: pd.DataFrame,
        dtype: Optional[Union[TensorDtype, Dict[str, TensorDtype]]],
    ) -> pd.DataFrame:
        numpy_input = _convert_pandas_to_batch_type(
            data,
            BatchFormat.NUMPY,
            self._cast_tensor_columns,
        )
        numpy_output = self._predict_numpy(numpy_input, dtype)
        return _convert_batch_type_to_pandas(numpy_output)

    def _predict_numpy(
        self,
        data: Union[np.ndarray, Dict[str, np.ndarray]],
        dtype: Optional[Union[TensorDtype, Dict[str, TensorDtype]]],
    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        # Single column selection return numpy array so preprocessors can be
        # reused in both training and prediction
        if isinstance(data, dict) and len(data) == 1:
            data = next(iter(data.values()))
        model_input = self._arrays_to_tensors(data, dtype)
        model_output = self.call_model(model_input)
        # TODO (jiaodong): Investigate perf implication of this.
        # Move DL Tensor to CPU and convert to numpy.
        if isinstance(model_output, dict):
            return {k: self._tensor_to_array(v) for k, v in model_output.items()}
        else:
            return {"predictions": self._tensor_to_array(model_output)}
.venv/lib/python3.11/site-packages/ray/train/_internal/framework_checkpoint.py ADDED
@@ -0,0 +1,45 @@
from typing import Optional

import ray.cloudpickle as ray_pickle
from ray._private.utils import binary_to_hex, hex_to_binary
from ray.data.preprocessor import Preprocessor
from ray.train._checkpoint import Checkpoint

PREPROCESSOR_KEY = "preprocessor_pkl"


class FrameworkCheckpoint(Checkpoint):
    """A checkpoint to preserve the functionality of legacy
    framework-specific checkpoints.

    Example:

        >>> import tempfile
        >>> checkpoint = FrameworkCheckpoint(tempfile.mkdtemp())
        >>> checkpoint.get_preprocessor() is None
        True
        >>> preprocessor = Preprocessor()
        >>> preprocessor._attr = 1234
        >>> checkpoint.set_preprocessor(preprocessor)
        >>> checkpoint.get_preprocessor()._attr
        1234
    """

    def get_preprocessor(self) -> Optional[Preprocessor]:
        """Return the preprocessor stored in the checkpoint.

        Returns:
            The preprocessor stored in the checkpoint, or ``None`` if no
            preprocessor was stored.
        """
        metadata = self.get_metadata()
        preprocessor_bytes = metadata.get(PREPROCESSOR_KEY)
        if preprocessor_bytes is None:
            return None
        return ray_pickle.loads(hex_to_binary(preprocessor_bytes))

    def set_preprocessor(self, preprocessor: Preprocessor):
        """Store a preprocessor with the checkpoint."""
        self.update_metadata(
            {PREPROCESSOR_KEY: binary_to_hex(ray_pickle.dumps(preprocessor))}
        )
.venv/lib/python3.11/site-packages/ray/train/_internal/session.py ADDED
@@ -0,0 +1,1163 @@
import functools
import logging
import os
import platform
import queue
import sys
import threading
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type

import ray
from ray.air._internal.util import RunnerThread, StartTraceback
from ray.air.constants import (
    _ERROR_FETCH_TIMEOUT,
    _RESULT_FETCH_TIMEOUT,
    SESSION_MISUSE_LOG_ONCE_KEY,
    TIME_THIS_ITER_S,
    TIMESTAMP,
)
from ray.data import Dataset
from ray.train import Checkpoint
from ray.train._internal.accelerator import Accelerator
from ray.train._internal.storage import StorageContext
from ray.train.constants import (
    CHECKPOINT_DIR_NAME,
    DETAILED_AUTOFILLED_KEYS,
    RAY_CHDIR_TO_TRIAL_DIR,
    TIME_TOTAL_S,
    WORKER_HOSTNAME,
    WORKER_NODE_IP,
    WORKER_PID,
    _v2_migration_warnings_enabled,
)
from ray.train.error import SessionMisuseError
from ray.train.utils import _log_deprecation_warning
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.debug import log_once
from ray.util.placement_group import _valid_resource_shape
from ray.util.scheduling_strategies import (
    PlacementGroupSchedulingStrategy,
    SchedulingStrategyT,
)

if TYPE_CHECKING:
    from ray.data import DataIterator
    from ray.tune.execution.placement_groups import PlacementGroupFactory


logger = logging.getLogger(__name__)


@dataclass
class TrialInfo:
    """The trial information to propagate to TrainSession."""

    name: str
    id: str
    resources: Dict[str, float]
    logdir: str
    driver_ip: str
    driver_node_id: str
    experiment_name: Optional[str] = None
    run_id: Optional[str] = None


class _FutureTrainingResult:
    """A future that will be resolved to a `_TrainingResult`.

    This is needed for specific schedulers such as PBT that schedule saves.

    This wrapper should be removed after refactoring PBT to not schedule saves anymore.
75
+ """
76
+
77
+ def __init__(self, future: ray.ObjectRef):
78
+ self.future = future
79
+
80
+ def resolve(self, block: bool = True) -> Optional["_TrainingResult"]:
81
+ """Resolve into ``_TrainingResult``.
82
+
83
+ This will return None for function trainables if no checkpoint has been
84
+ saved before.
85
+ """
86
+ if block:
87
+ timeout = None
88
+ else:
89
+ timeout = 1e-9
90
+ try:
91
+ return ray.get(self.future, timeout=timeout)
92
+ except TimeoutError:
93
+ # Not ready, yet
94
+ pass
95
+ except Exception as exc:
96
+ logger.error(f"Error resolving result: {exc}")
97
+
98
+
99
+ class _TrainingResult:
100
+ """A (checkpoint, metrics) result reported by the user."""
101
+
102
+ def __init__(self, checkpoint: Optional[Checkpoint], metrics: Dict[str, Any]):
103
+ self.checkpoint = checkpoint
104
+ self.metrics = metrics
105
+
106
+ def __repr__(self) -> str:
107
+ return f"TrainingResult(checkpoint={self.checkpoint}, metrics={self.metrics})"
108
+
109
+
110
+ # TODO(xwjiang): This needs a better name.
111
+ @DeveloperAPI
112
+ class _TrainSession:
113
+ """Holds information for training on each worker."""
114
+
115
+ def __init__(
116
+ self,
117
+ training_func: Callable,
118
+ world_rank: Optional[int],
119
+ local_rank: Optional[int],
120
+ node_rank: Optional[int],
121
+ local_world_size: Optional[int],
122
+ world_size: Optional[int],
123
+ trial_info: Optional[TrialInfo] = None,
124
+ dataset_shard: Optional[Dict[str, Dataset]] = None,
125
+ metadata: Dict[str, Any] = None,
126
+ checkpoint: Optional[Checkpoint] = None,
127
+ detailed_autofilled_metrics: bool = False,
128
+ storage: Optional[StorageContext] = None,
129
+ synchronous_result_reporting: bool = False,
130
+ ):
131
+ # `synchronous_result_reporting` refers to whether or not the
132
+ # training function is immediately unblocked to continue running
133
+ # after the main thread receives its result.
134
+ # Ex 1: For 2 Ray Train workers with synchronous_result_reporting=True,
135
+ # the worker that produces a result first will immediately continue
136
+ # onto the next iteration.
137
+ # Ex 2: For a Tune function Trainable with `synchronous_result_reporting=False`,
138
+ # training will only continue with an explicit call to `session.get_next`.
139
+ # Synchronous reporting in example 2 is needed for Tune schedulers to
140
+ # be able to stop the execution of the training function at will,
141
+ # for advanced pausing schedulers (PBT, BOHB) and actor reuse.
142
+ self.synchronous_result_reporting = synchronous_result_reporting
143
+
144
+ # Ray Train worker properties
145
+ # Note: These are set to None for Tune function Trainables.
146
+ self.dataset_shard = dataset_shard
147
+ self.metadata = metadata
148
+
149
+ self.world_rank = world_rank
150
+ self.local_rank = local_rank
151
+ self.node_rank = node_rank
152
+ self.local_world_size = local_world_size
153
+ self.world_size = world_size
154
+
155
+ assert storage
156
+ logger.debug(f"StorageContext on SESSION (rank={world_rank}):\n{storage}")
157
+
158
+ # NOTE: `reset` will initialize many properties needed to start running the
159
+ # training_func as a thread.
160
+ self.reset(
161
+ training_func=training_func,
162
+ trial_info=trial_info,
163
+ storage=storage,
164
+ loaded_checkpoint=checkpoint,
165
+ )
166
+
167
+ # Autofilled metrics attributes.
168
+ self.detailed_autofilled_metrics = detailed_autofilled_metrics
169
+ self.last_report_time = time.time()
170
+ self.iteration = 0
171
+ self.time_total = 0.0
172
+ self.local_ip = self.get_current_ip()
173
+
174
+ self.accelerator = None
175
+ self._state = {}
176
+
177
+ def get_state(self, key: str) -> Any:
178
+ return self._state.get(key)
179
+
180
+ def set_state(self, key: str, value: Any):
181
+ self._state[key] = value
182
+
183
+ def get_current_ip(self):
184
+ self.local_ip = ray.util.get_node_ip_address()
185
+ return self.local_ip
186
+
187
+ def start(self):
188
+ """Starts the training thread."""
189
+ self.training_started = True
190
+ self.training_thread.start()
191
+
192
+ def reset(
193
+ self,
194
+ training_func: Callable,
195
+ trial_info: TrialInfo,
196
+ storage: StorageContext,
197
+ loaded_checkpoint=None,
198
+ ):
199
+ # This lock is used to control the execution of the training thread.
200
+ self.continue_lock = threading.Semaphore(0)
201
+
202
+ # This event is used to signal the training thread to stop.
203
+ self.stop_event = threading.Event()
204
+
205
+ # Queue for sending results across threads.
206
+ self.result_queue = queue.Queue(1)
207
+
208
+ # Queue for raising exceptions from runner thread to main thread.
209
+ # The error queue has a max size of one to prevent stacking errors and force
210
+ # error reporting to block until finished.
211
+ self.error_queue = queue.Queue(1)
212
+
213
+ # The Thread object that is running the training function.
214
+ self.training_thread = RunnerThread(
215
+ target=training_func, daemon=True, error_queue=self.error_queue
216
+ )
217
+
218
+ # Possibly override with new state
219
+ self.trial_info = trial_info
220
+ self.storage = storage
221
+ self.loaded_checkpoint = loaded_checkpoint
222
+
223
+ # Reset state
224
+ self._state = {}
225
+ self.ignore_report = False
226
+ self.training_started = False
227
+ self._first_report = True
228
+
229
+ # Change the working directory to a special trial folder.
230
+ # This is to ensure that all Ray Train workers have a common working directory.
231
+ os.makedirs(storage.trial_working_directory, exist_ok=True)
232
+ if bool(int(os.environ.get(RAY_CHDIR_TO_TRIAL_DIR, "1"))):
233
+ logger.debug(
234
+ f"Changing the working directory to: {storage.trial_working_directory}"
235
+ )
236
+ os.chdir(storage.trial_working_directory)
237
+
238
+ def pause_reporting(self):
239
+ """Ignore all future ``session.report()`` calls."""
240
+ self.ignore_report = True
241
+
242
+ def finish(self, timeout: Optional[float] = None) -> Optional[Any]:
243
+ """Finishes the training thread.
244
+
245
+ Raises any Exception from training.
246
+ """
247
+ # Set the stop event for the training thread to gracefully exit.
248
+ self.stop_event.set()
249
+
250
+ # Release the lock so that training thread can process this event.
251
+ self.continue_lock.release()
252
+
253
+ # Force a final (blocking) sync of artifacts in the trial path to storage.
254
+ self.storage.persist_artifacts(force=True)
255
+
256
+ # Wait for training to finish.
257
+ # This will raise any errors that occur during training, including SystemError
258
+ # This returns the result of the training function.
259
+ output = None
260
+ if self.training_started:
261
+ output = self.training_thread.join(timeout=timeout)
262
+
263
+ return output
264
+
265
+ def get_next(self) -> Optional[_TrainingResult]:
266
+ """Gets the next ``_TrainingResult`` from the result queue.
267
+
268
+ If the result queue is empty, then this function returns ``None``.
269
+ """
270
+ if not self.training_started:
271
+ raise RuntimeError("Please call start before calling get_next.")
272
+
273
+ if self.synchronous_result_reporting:
274
+ # There's no need to release the lock on the first report
275
+ # since `start` already started the training thread.
276
+ if not self._first_report:
277
+ # Release the lock to trigger training to continue,
278
+ # until the next call to report.
279
+ self.continue_lock.release()
280
+ self._first_report = False
281
+
282
+ result = None
283
+ # While training is still ongoing, attempt to get the result.
284
+ while result is None and self.training_thread.is_alive():
285
+ try:
286
+ result = self.result_queue.get(
287
+ block=True, timeout=_RESULT_FETCH_TIMEOUT
288
+ )
289
+ except queue.Empty:
290
+ pass
291
+
292
+ # If no result was found, then the runner must no longer be alive.
293
+ if result is None:
294
+ # Try one last time to fetch results in case results were
295
+ # reported in between the time of the last check and the
296
+ # termination of the thread runner.
297
+ try:
298
+ result = self.result_queue.get(
299
+ block=False, timeout=_RESULT_FETCH_TIMEOUT
300
+ )
301
+ except queue.Empty:
302
+ pass
303
+
304
+ # check if error occurred inside the thread runner.
305
+ if result is None:
306
+ # only raise an error from the runner if all results are consumed
307
+ self._report_thread_runner_error(block=True)
308
+ else:
309
+ if not self.error_queue.empty():
310
+ logger.debug(
311
+ (
312
+ "Runner error waiting to be raised in main thread. "
313
+ "Logging all available results first."
314
+ )
315
+ )
316
+
317
+ if not self.synchronous_result_reporting:
318
+ # At this point, the training thread has reached
319
+ # the `train.report` and is blocked there.
320
+ # If performing asynchronous result reporting,
321
+ # release the lock to allow each worker to keep training
322
+ # immediately after the coordinator fetches their result.
323
+ self.continue_lock.release()
324
+
325
+ # Return None if there are no more results to fetch.
326
+ return result
327
+
328
+ def _auto_fill_metrics(self, result: dict) -> dict:
329
+ """Add autofilled metrics and update attributes."""
330
+ current_time = time.time()
331
+ current_datetime = datetime.now()
332
+ if TIME_THIS_ITER_S in result:
333
+ time_this_iter = result[TIME_THIS_ITER_S]
334
+ else:
335
+ time_this_iter = current_time - self.last_report_time
336
+ self.iteration += 1
337
+ self.time_total += time_this_iter
338
+ self.last_report_time = current_time
339
+
340
+ auto_filled_metrics = {
341
+ TIMESTAMP: int(time.mktime(current_datetime.timetuple())),
342
+ TIME_TOTAL_S: self.time_total,
343
+ WORKER_PID: os.getpid(),
344
+ WORKER_HOSTNAME: platform.node(),
345
+ WORKER_NODE_IP: self.local_ip,
346
+ }
347
+
348
+ if not self.detailed_autofilled_metrics:
349
+ auto_filled_metrics = {
350
+ k: v
351
+ for k, v in auto_filled_metrics.items()
352
+ if k not in DETAILED_AUTOFILLED_KEYS
353
+ }
354
+
355
+ result = result.copy()
356
+ result.update(auto_filled_metrics)
357
+ return result
358
+
359
+ def _auto_fill_checkpoint_metrics(self, result: dict) -> dict:
360
+ """Add autofilled metrics and update attributes."""
361
+ current_datetime = datetime.now()
362
+
363
+ auto_filled_metrics = {
364
+ TIMESTAMP: int(time.mktime(current_datetime.timetuple()))
365
+ }
366
+ result = result.copy()
367
+ result.update(auto_filled_metrics)
368
+ return result
369
+
370
+ def _report_thread_runner_error(self, block=False):
371
+ try:
372
+ e = self.error_queue.get(block=block, timeout=_ERROR_FETCH_TIMEOUT)
373
+ raise StartTraceback from e
374
+ except queue.Empty:
375
+ pass
376
+
377
+ def _report_training_result(self, training_result: _TrainingResult) -> None:
378
+ """Place a training result on the result queue for the main thread to process,
379
+ then block until the main thread signals that training should continue.
380
+
381
+ NOTE: This is used internally to report results from Train to Tune
382
+ without persisting checkpoints to storage 2 times.
383
+ `report` is the public API that directly persists to storage, which
384
+ should only be called by user code.
385
+ """
386
+ if training_result.checkpoint:
387
+ # NOTE: This populates `train.get_checkpoint`
388
+ self.loaded_checkpoint = training_result.checkpoint
389
+
390
+ # Add result to a thread-safe queue.
391
+ self.result_queue.put(training_result, block=True)
392
+
393
+ # Acquire lock to stop the training thread until main thread
394
+ # triggers resume.
395
+ self.continue_lock.acquire()
396
+
397
+ # If the trial should be terminated, exit gracefully.
398
+ # NOTE: This is only really useful if `synchronous_result_reporting=True`.
399
+ # Otherwise, the lock is immediately released on reporting, and this
400
+ # check is skipped before the main thread decides to set the stop event.
401
+ if self.stop_event.is_set():
402
+ self.stop_event.clear()
403
+ sys.exit(0)
404
+
405
+ def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
406
+ # Special case: early fail for Torch tensors
407
+ if "torch" in sys.modules:
408
+ from ray.air._internal.torch_utils import contains_tensor
409
+
410
+ if contains_tensor(metrics):
411
+ raise ValueError(
412
+ "Passing objects containing Torch tensors as metrics "
413
+ "is not supported as it will throw an exception on "
414
+ "deserialization. You can either convert the tensors "
415
+ "to Python objects or report a `train.Checkpoint` "
416
+ "with `ray.train.report` to store your Torch objects."
417
+ )
418
+
419
+ if self.ignore_report:
420
+ return
421
+
422
+ metrics = self._auto_fill_metrics(metrics)
423
+
424
+ persisted_checkpoint = None
425
+ if checkpoint:
426
+ self.storage._update_checkpoint_index(metrics)
427
+
428
+ # Persist the reported checkpoint files to storage.
429
+ persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)
430
+
431
+ metrics[CHECKPOINT_DIR_NAME] = self.storage.checkpoint_dir_name
432
+ else:
433
+ metrics[CHECKPOINT_DIR_NAME] = None
434
+
435
+ # Persist trial artifacts to storage.
436
+ force_artifact_sync = (
437
+ persisted_checkpoint
438
+ and self.storage.sync_config.sync_artifacts_on_checkpoint
439
+ )
440
+ self.storage.persist_artifacts(force=force_artifact_sync)
441
+
442
+ # Set additional user metadata from the Trainer.
443
+ if persisted_checkpoint and self.metadata:
444
+ user_metadata = persisted_checkpoint.get_metadata()
445
+ for k, v in self.metadata.items():
446
+ # Update keys not already set by the user. This gives user-set keys
447
+ # precedence over keys set at the Trainer level.
448
+ if k not in user_metadata:
449
+ user_metadata[k] = v
450
+ persisted_checkpoint.set_metadata(user_metadata)
451
+
452
+ result = _TrainingResult(checkpoint=persisted_checkpoint, metrics=metrics)
453
+
454
+ self._report_training_result(result)
455
+
456
+ @property
457
+ def experiment_name(self) -> str:
458
+ return self.trial_info.experiment_name
459
+
460
+ @property
461
+ def trial_name(self) -> str:
462
+ return self.trial_info.name
463
+
464
+ @property
465
+ def trial_id(self) -> str:
466
+ return self.trial_info.id
467
+
468
+ @property
469
+ def run_id(self) -> str:
470
+ return self.trial_info.run_id
471
+
472
+ @property
473
+ def trial_resources(self) -> "PlacementGroupFactory":
474
+ return self.trial_info.resources
475
+
476
+ @property
477
+ def trial_dir(self) -> str:
478
+ return self.trial_info.logdir
479
+
480
+ def get_dataset_shard(
481
+ self,
482
+ dataset_name: Optional[str] = None,
483
+ ) -> Optional["DataIterator"]:
484
+ shard = self.dataset_shard
485
+ if shard is None:
486
+ warnings.warn(
487
+ "No dataset passed in. Returning None. Make sure to "
488
+ "pass in a Dataset to Trainer.run to use this "
489
+ "function."
490
+ )
491
+ elif isinstance(shard, dict):
492
+ if not dataset_name:
493
+ raise RuntimeError(
494
+ "Multiple datasets were passed into ``Trainer``, "
495
+ "but no ``dataset_name`` is passed into "
496
+ "``get_dataset_shard``. Please specify which "
497
+ "dataset shard to retrieve."
498
+ )
499
+ return shard.get(dataset_name)
500
+ return shard
501
+
502
+
503
+ # Cache of resource dicts that have been checked by the launch hook already.
504
+ _checked_resources: Set[frozenset] = set()
505
+
506
+ # Global _TrainSession object initialized by Ray Tune function trainables
507
+ # and Ray Train V1 workers.
508
+ _session: Optional[_TrainSession] = None
509
+
510
+
511
+ def _tune_task_and_actor_launch_hook(
512
+ fn, resources: Dict[str, float], strategy: Optional[SchedulingStrategyT]
513
+ ):
514
+ """Launch hook to catch nested tasks that can't fit in the placement group.
515
+
516
+ This gives users a nice warning in case they launch a nested task in a Tune trial
517
+ without reserving resources in the trial placement group to fit it.
518
+ """
519
+
520
+ # Already checked, skip for performance reasons.
521
+ key = frozenset({(k, v) for k, v in resources.items() if v > 0})
522
+ if not key or key in _checked_resources:
523
+ return
524
+
525
+ # No need to check if placement group is None.
526
+ if (
527
+ not isinstance(strategy, PlacementGroupSchedulingStrategy)
528
+ or strategy.placement_group is None
529
+ ):
530
+ return
531
+
532
+ # Check if the resource request is targeting the current placement group.
533
+ cur_pg = ray.util.get_current_placement_group()
534
+ if not cur_pg or strategy.placement_group.id != cur_pg.id:
535
+ return
536
+
537
+ _checked_resources.add(key)
538
+
539
+ # Check if the request can be fulfilled by the current placement group.
540
+ pgf = get_trial_resources()
541
+
542
+ if pgf.head_bundle_is_empty:
543
+ available_bundles = cur_pg.bundle_specs[0:]
544
+ else:
545
+ available_bundles = cur_pg.bundle_specs[1:]
546
+
547
+ # Check if the request can be fulfilled by the current placement group.
548
+ if _valid_resource_shape(resources, available_bundles):
549
+ return
550
+
551
+ if fn.class_name:
552
+ submitted = "actor"
553
+ name = fn.module_name + "." + fn.class_name + "." + fn.function_name
554
+ else:
555
+ submitted = "task"
556
+ name = fn.module_name + "." + fn.function_name
557
+
558
+ # Normalize the resource spec so it looks the same as the placement group bundle.
559
+ main_resources = cur_pg.bundle_specs[0]
560
+ resources = {k: float(v) for k, v in resources.items() if v > 0}
561
+
562
+ raise RuntimeError(
563
+ f"No trial resources are available for launching the {submitted} `{name}`. "
564
+ "To resolve this, specify the Tune option:\n\n"
565
+ "> resources_per_trial=tune.PlacementGroupFactory(\n"
566
+ f"> [{main_resources}] + [{resources}] * N\n"
567
+ "> )\n\n"
568
+ f"Where `N` is the number of slots to reserve for trial {submitted}s. "
569
+ "If you are using a Ray training library, there might be a utility function "
570
+ "to set this automatically for you. For more information, refer to "
571
+ "https://docs.ray.io/en/latest/tune/tutorials/tune-resources.html"
572
+ )
573
+
574
+
575
+ def init_session(*args, **kwargs) -> None:
576
+ global _session
577
+ if _session:
578
+ raise ValueError(
579
+ "A Train session is already in use. Do not call "
580
+ "`init_session()` manually."
581
+ )
582
+
583
+ # Setup hooks for generating placement group resource deadlock warnings.
584
+ from ray import actor, remote_function
585
+
586
+ if "TUNE_DISABLE_RESOURCE_CHECKS" not in os.environ:
587
+ actor._actor_launch_hook = _tune_task_and_actor_launch_hook
588
+ remote_function._task_launch_hook = _tune_task_and_actor_launch_hook
589
+
590
+ _session = _TrainSession(*args, **kwargs)
591
+
592
+
593
+ def get_session() -> Optional[_TrainSession]:
594
+ return _session
595
+
596
+
597
+ def shutdown_session():
598
+ """Shuts down the initialized session."""
599
+ global _session
600
+ _session = None
601
+
602
+
603
+ def _raise_accelerator_session_misuse():
604
+ """Raises a SessionMisuseError because a utility function was used improperly."""
605
+ raise SessionMisuseError(
606
+ "prepare/accelerate utility functions should be called inside a training "
607
+ "function executed by `Trainer.run`"
608
+ )
609
+
610
+
611
+ def get_accelerator(default_accelerator_cls: Type[Accelerator]) -> Accelerator:
612
+ """The accelerator for this training session.
613
+
614
+ If an accelerator has not been set, then this method will construct an
615
+ accelerator using the provided accelerator class.
616
+
617
+ Raises:
618
+ SessionMisuseError: if the session is uninitialized.
619
+ """
620
+ session = get_session()
621
+ if session is None:
622
+ _raise_accelerator_session_misuse()
623
+ if session.accelerator is None:
624
+ session.accelerator = default_accelerator_cls()
625
+ return session.accelerator
626
+
627
+
628
+ def set_accelerator(accelerator: Accelerator) -> None:
629
+ """Sets the accelerator for this training session.
630
+
631
+ Args:
632
+ accelerator: The accelerator to use for training.
633
+
634
+ Raises:
635
+ SessionMisuseError: if the session is uninitialized.
636
+ RuntimeError: if the accelerator has already been set.
637
+ """
638
+ session = get_session()
639
+ if session is None:
640
+ _raise_accelerator_session_misuse()
641
+ if session.accelerator is not None:
642
+ raise RuntimeError("Cannot change accelerator once set.")
643
+ session.accelerator = accelerator
644
+
645
+
646
+ def _warn_session_misuse(default_value: Any = None):
647
+ """Warns if fn is being used outside of session and returns ``default_value``."""
648
+
649
+ def inner(fn: Callable):
650
+ fn_name = fn.__name__
651
+
652
+ @functools.wraps(fn)
653
+ def wrapper(*args, **kwargs):
654
+ session = get_session()
655
+ if not session:
656
+ if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"):
657
+ warnings.warn(
658
+ f"`{fn_name}` is meant to only be "
659
+ "called inside a function that is executed by a Tuner"
660
+ f" or Trainer. Returning `{default_value}`."
661
+ )
662
+ return default_value
663
+ return fn(*args, **kwargs)
664
+
665
+ return wrapper
666
+
667
+ return inner
668
+
669
+
670
+ @PublicAPI(stability="stable")
671
+ @_warn_session_misuse()
672
+ def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
673
+ """Report metrics and optionally save a checkpoint.
674
+
675
+ If a checkpoint is provided, it will be
676
+ :ref:`persisted to storage <persistent-storage-guide>`.
677
+
678
+ If this is called in multiple distributed training workers:
679
+
680
+ - Only the metrics reported by the rank 0 worker will be tracked by Ray Train.
681
+ See :ref:`the metrics logging guide <train-monitoring-and-logging>`.
682
+ - A checkpoint will be registered as long as one or more workers report a
683
+ checkpoint that is not None.
684
+ See the :ref:`checkpointing guide <train-dl-saving-checkpoints>`.
685
+ - Checkpoints from multiple workers will be merged into one directory
686
+ in persistent storage.
687
+ See :ref:`the distributed checkpointing guide <train-distributed-checkpointing>`.
688
+
689
+ .. note::
690
+
691
+ Each invocation of this method will automatically increment the underlying
692
+ ``training_iteration`` number. The physical meaning of this "iteration" is
693
+ defined by the user depending on how often they call ``report``.
694
+ It does not necessarily map to one epoch.
695
+
696
+ .. warning::
697
+
698
+ All workers must call `ray.train.report` the same number of times
699
+ so that Ray Train can properly synchronize the training state across
700
+ workers. Otherwise, your training will hang.
701
+
702
+ .. warning::
703
+
704
+ This method does NOT act as a barrier for distributed training workers.
705
+ Workers will upload their checkpoint, then continue training immediately.
706
+ If you need to synchronize workers, you can use a framework-native barrier
707
+ such as `torch.distributed.barrier()`.
708
+
709
+ Example:
710
+
711
+ .. testcode::
712
+
713
+ import tempfile
714
+
715
+ from ray import train
716
+ from ray.train import Checkpoint
717
+ from ray.train.torch import TorchTrainer
718
+
719
+
720
+ def train_func(config):
721
+ start_epoch = 0
722
+ checkpoint = train.get_checkpoint()
723
+ if checkpoint:
724
+ with checkpoint.as_directory() as checkpoint_dir:
725
+ # Load back training state
726
+ ...
727
+
728
+ for epoch in range(start_epoch, config.get("num_epochs", 10)):
729
+ # Do training...
730
+
731
+ metrics = {"loss": ...}
732
+
733
+ with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
734
+ # Save the checkpoint...
735
+ # torch.save(...)
736
+
737
+ checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
738
+
739
+ # Example: Only the rank 0 worker uploads the checkpoint.
740
+ if train.get_context().get_world_rank() == 0:
741
+ train.report(metrics, checkpoint=checkpoint)
742
+ else:
743
+ train.report(metrics, checkpoint=None)
744
+
745
+ trainer = TorchTrainer(
746
+ train_func, scaling_config=train.ScalingConfig(num_workers=2)
747
+ )
748
+
749
+ Args:
750
+ metrics: The metrics you want to report.
751
+ checkpoint: The optional checkpoint you want to report.
752
+ """
753
+ # If we are running in a Tune function, switch to `ray.tune.report`.
754
+ from ray.tune.trainable.trainable_fn_utils import _in_tune_session
755
+
756
+ if _in_tune_session():
757
+ import ray.tune
758
+
759
+ if _v2_migration_warnings_enabled():
760
+ _log_deprecation_warning(
761
+ "`ray.train.report` should be switched to "
762
+ "`ray.tune.report` when running in a function "
763
+ "passed to Ray Tune. This will be an error in the future."
764
+ )
765
+ return ray.tune.report(metrics, checkpoint=checkpoint)
766
+
767
+ get_session().report(metrics, checkpoint=checkpoint)
768
+
769
+
770
+ @PublicAPI(stability="stable")
771
+ @_warn_session_misuse()
772
+ def get_checkpoint() -> Optional[Checkpoint]:
773
+ """Access the latest reported checkpoint to resume from if one exists.
774
+
775
+ Example:
776
+
777
+ .. testcode::
778
+
779
+ import tempfile
780
+
781
+ from ray import train
782
+ from ray.train import Checkpoint
783
+ from ray.train.torch import TorchTrainer
784
+
785
+
786
+ def train_func(config):
787
+ start_epoch = 0
788
+ checkpoint = train.get_checkpoint()
789
+ if checkpoint:
790
+ with checkpoint.as_directory() as checkpoint_dir:
791
+ # Load back training state
792
+ ...
793
+
794
+ for epoch in range(start_epoch, config.get("num_epochs", 10)):
795
+ # Do training...
796
+
797
+ metrics = {"loss": ...}
798
+
799
+ with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
800
+ # Save the checkpoint...
801
+
802
+ checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
803
+ train.report(metrics, checkpoint=checkpoint)
804
+
805
+ trainer = TorchTrainer(
806
+ train_func, scaling_config=train.ScalingConfig(num_workers=2)
807
+ )
808
+
809
+ Returns:
810
+ Checkpoint object if the session is currently being resumed.
811
+ Otherwise, return None.
812
+ """
813
+ # If we are running in a Tune function, switch to `ray.tune.get_checkpoint`.
814
+ from ray.tune.trainable.trainable_fn_utils import _in_tune_session
815
+
816
+ if _in_tune_session():
817
+ import ray.tune
818
+
819
+ if _v2_migration_warnings_enabled():
820
+ _log_deprecation_warning(
821
+ "`ray.train.get_checkpoint` should be switched to "
822
+ "`ray.tune.get_checkpoint` when running in a function "
823
+ "passed to Ray Tune. This will be an error in the future."
824
+ )
825
+ return ray.tune.get_checkpoint()
826
+
827
+ return get_session().loaded_checkpoint
828
+
829
+
830
+ @PublicAPI(stability="beta")
831
+ @_warn_session_misuse()
832
+ def get_metadata() -> Dict[str, Any]:
833
+ """User metadata dict passed to the Trainer constructor."""
834
+ return get_session().metadata
835
+
836
+
837
+ @PublicAPI(stability="beta")
838
+ @_warn_session_misuse()
839
+ def get_experiment_name() -> str:
840
+ """Experiment name for the corresponding trial."""
841
+ return get_session().experiment_name
842
+
843
+
844
+ @PublicAPI(stability="beta")
845
+ @_warn_session_misuse()
846
+ def get_trial_name() -> str:
847
+ """Trial name for the corresponding trial."""
848
+ return get_session().trial_name
849
+
850
+
851
+ @PublicAPI(stability="beta")
852
+ @_warn_session_misuse()
853
+ def get_trial_id() -> str:
854
+ """Trial id for the corresponding trial."""
855
+ return get_session().trial_id
856
+
857
+
858
+ @PublicAPI(stability="alpha")
859
+ @_warn_session_misuse()
860
+ def get_run_id() -> str:
861
+ """Unique Train Run id for the corresponding trial."""
862
+ return get_session().run_id
863
+
864
+
865
+ @PublicAPI(stability="beta")
866
+ @_warn_session_misuse()
867
+ def get_trial_resources() -> "PlacementGroupFactory":
868
+ """Trial resources for the corresponding trial."""
869
+ return get_session().trial_resources
870
+
871
+
872
+ @PublicAPI(stability="beta")
873
+ @_warn_session_misuse()
874
+ def get_trial_dir() -> str:
875
+ """Log directory corresponding to the trial directory for a Tune session.
876
+ If calling from a Train session, this will give the trial directory of its parent
877
+ Tune session.
878
+
879
+ .. testcode::
880
+
881
+ from ray import train, tune
882
+
883
+ def train_func(config):
884
+ print(train.get_context().get_trial_dir())
885
+
886
+ tuner = tune.Tuner(train_func)
887
+ tuner.fit()
888
+
889
+ .. testoutput::
890
+ :options: +MOCK
891
+
892
+ /Users/root/ray_results/train_func_2023-07-19_15-01-37/train_func_d620c_00000_0_2023-07-19_15-01-40
893
+ """
894
+ return get_session().trial_dir
895
+
896
+
897
+ @PublicAPI(stability="beta")
898
+ @_warn_session_misuse(default_value=1)
899
+ def get_world_size() -> int:
900
+ """Get the current world size (i.e. total number of workers) for this run.
901
+
902
+ .. testcode::
903
+
904
+ import ray
905
+ from ray import train
906
+ from ray.train import ScalingConfig
907
+ from ray.train.tensorflow import TensorflowTrainer
908
+
909
+ NUM_WORKERS = 2
910
+
911
+ def train_loop_per_worker(config):
912
+ assert train.get_context().get_world_size() == NUM_WORKERS
913
+
914
+ train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
915
+ trainer = TensorflowTrainer(
916
+ train_loop_per_worker,
917
+ scaling_config=ScalingConfig(num_workers=NUM_WORKERS),
918
+ datasets={"train": train_dataset}
919
+ )
920
+ trainer.fit()
921
+
922
+ .. testoutput::
923
+ :hide:
924
+
925
+ ...
926
+ """
927
+ session = get_session()
928
+ if not hasattr(session, "world_size"):
929
+ raise RuntimeError(
930
+ "`get_world_size` can only be called for TrainSession! "
931
+ "Make sure you only use that in the `train_loop_per_worker` function "
932
+ "that is passed into `DataParallelTrainer`."
933
+ )
934
+ return session.world_size
935
+
936
+
937
+ @PublicAPI(stability="beta")
938
+ @_warn_session_misuse(default_value=0)
939
+ def get_world_rank() -> int:
940
+ """Get the world rank of this worker.
941
+
942
+ .. testcode::
943
+
944
+ import ray
945
+ from ray import train
946
+ from ray.train import ScalingConfig
947
+ from ray.train.tensorflow import TensorflowTrainer
948
+
949
+ def train_loop_per_worker(config):
950
+ if train.get_context().get_world_rank() == 0:
951
+ print("Worker 0")
952
+
953
+ train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
954
+ trainer = TensorflowTrainer(
955
+ train_loop_per_worker,
956
+ scaling_config=ScalingConfig(num_workers=2),
957
+ datasets={"train": train_dataset}
958
+ )
959
+ trainer.fit()
960
+
961
+ .. testoutput::
962
+ :hide:
963
+
964
+ ...
965
+ """
966
+ session = get_session()
967
+ if not hasattr(session, "world_rank"):
968
+ raise RuntimeError(
969
+ "`get_world_rank` can only be called for TrainSession! "
970
+ "Make sure you only use that in the `train_loop_per_worker` function "
971
+ "that is passed into `DataParallelTrainer`."
972
+ )
973
+ return session.world_rank
974
+
975
+
976
+ @PublicAPI(stability="beta")
977
+ @_warn_session_misuse(default_value=0)
978
+ def get_local_rank() -> int:
979
+ """Get the local rank of this worker (rank of the worker on its node).
980
+
981
+ .. testcode::
982
+
983
+ import torch
984
+
985
+ import ray
986
+ from ray import train
987
+ from ray.train import ScalingConfig
988
+ from ray.train.torch import TorchTrainer
989
+
990
+ def train_loop_per_worker(config):
991
+ if torch.cuda.is_available():
992
+ torch.cuda.set_device(train.get_context().get_local_rank())
993
+ ...
994
+
995
+ train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
996
+ trainer = TorchTrainer(
997
+ train_loop_per_worker,
998
+ scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
999
+ datasets={"train": train_dataset}
1000
+ )
1001
+ trainer.fit()
1002
+
1003
+ .. testoutput::
1004
+ :hide:
1005
+
1006
+ ...
1007
+ """
1008
+ session = get_session()
1009
+ if not hasattr(session, "local_rank"):
1010
+ raise RuntimeError(
1011
+ "`get_local_rank` can only be called for TrainSession! "
1012
+ "Make sure you only use that in the `train_loop_per_worker` function "
1013
+ "that is passed into `DataParallelTrainer`."
1014
+ )
1015
+ return session.local_rank
1016
+
1017
+
1018
+ @PublicAPI(stability="beta")
1019
+ @_warn_session_misuse(default_value=0)
1020
+ def get_local_world_size() -> int:
1021
+ """Get the local world size of this node (i.e. number of workers on this node).
1022
+
1023
+ Example:
1024
+
1025
+ .. testcode::
1026
+
1027
+ import ray
1028
+ from ray import train
1029
+ from ray.train import ScalingConfig
1030
+ from ray.train.torch import TorchTrainer
1031
+
1032
+ def train_loop_per_worker():
1033
+ print(train.get_context().get_local_world_size())
1034
+
1035
+ train_dataset = ray.data.from_items(
1036
+ [{"x": x, "y": x + 1} for x in range(32)])
1037
+ trainer = TorchTrainer(train_loop_per_worker,
1038
+ scaling_config=ScalingConfig(num_workers=1),
1039
+ datasets={"train": train_dataset})
1040
+ trainer.fit()
1041
+
1042
+ .. testoutput::
1043
+ :hide:
1044
+
1045
+ ...
1046
+ """
1047
+ session = get_session()
1048
+ if not hasattr(session, "local_world_size"):
1049
+ raise RuntimeError(
1050
+ "`get_local_world_size` can only be called for TrainSession! "
1051
+ "Make sure you only use that in the `train_loop_per_worker` function "
1052
+ "that is passed into `DataParallelTrainer`."
1053
+ )
1054
+ return session.local_world_size
1055
+
1056
+
1057
+ @PublicAPI(stability="beta")
1058
+ @_warn_session_misuse(default_value=0)
1059
+ def get_node_rank() -> int:
1060
+ """Get the rank of this node.
1061
+
1062
+ Example:
1063
+
1064
+ .. testcode::
1065
+
1066
+ import ray
1067
+ from ray import train
1068
+ from ray.train import ScalingConfig
1069
+ from ray.train.torch import TorchTrainer
1070
+
1071
+ def train_loop_per_worker():
1072
+ print(train.get_context().get_node_rank())
1073
+
1074
+ train_dataset = ray.data.from_items(
1075
+ [{"x": x, "y": x + 1} for x in range(32)])
1076
+ trainer = TorchTrainer(train_loop_per_worker,
1077
+ scaling_config=ScalingConfig(num_workers=1),
1078
+ datasets={"train": train_dataset})
1079
+ trainer.fit()
1080
+
1081
+ .. testoutput::
1082
+ :hide:
1083
+
1084
+ ...
1085
+ """
1086
+ session = get_session()
1087
+ if not hasattr(session, "node_rank"):
1088
+ raise RuntimeError(
1089
+ "`get_node_rank` can only be called for TrainSession! "
1090
+ "Make sure you only use that in the `train_loop_per_worker` function "
1091
+ "that is passed into `DataParallelTrainer`."
1092
+ )
1093
+ return session.node_rank
1094
+
1095
+
1096
+ @PublicAPI(stability="stable")
1097
+ @_warn_session_misuse()
1098
+ def get_dataset_shard(
1099
+ dataset_name: Optional[str] = None,
1100
+ ) -> Optional["DataIterator"]:
1101
+ """Returns the :class:`ray.data.DataIterator` shard for this worker.
1102
+
1103
+ Call :meth:`~ray.data.DataIterator.iter_torch_batches` or
1104
+ :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the
1105
+ appropriate framework-specific data type.
1106
+
1107
+ .. testcode::
1108
+
1109
+ import ray
1110
+ from ray import train
1111
+ from ray.train import ScalingConfig
1112
+ from ray.train.torch import TorchTrainer
1113
+
1114
+ def train_loop_per_worker(config):
1115
+ ...
1116
+ for epoch in range(2):
1117
+ # Trainer will automatically handle sharding.
1118
+ data_shard = train.get_dataset_shard("train")
1119
+ for batch in data_shard.iter_torch_batches():
1120
+ ...
1121
+
1122
+ train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
1123
+ trainer = TorchTrainer(
1124
+ train_loop_per_worker,
1125
+ scaling_config=ScalingConfig(num_workers=2),
1126
+ datasets={"train": train_dataset}
1127
+ )
1128
+ trainer.fit()
1129
+
1130
+ .. testoutput::
1131
+ :hide:
1132
+
1133
+ ...
1134
+
1135
+ Args:
1136
+ dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
1137
+ specifies which dataset shard to return.
1138
+
1139
+ Returns:
1140
+ The ``DataIterator`` shard to use for this worker.
1141
+ If no dataset is passed into Trainer, then return None.
1142
+ """
1143
+ session = get_session()
1144
+ if not hasattr(session, "get_dataset_shard"):
1145
+ raise RuntimeError(
1146
+ "`get_dataset_shard` can only be called for TrainSession! "
1147
+ "Make sure you only use that in the `train_loop_per_worker` function "
1148
+ "that is passed into `DataParallelTrainer`."
1149
+ )
1150
+ return session.get_dataset_shard(dataset_name)
1151
+
1152
+
1153
+ @DeveloperAPI
1154
+ @_warn_session_misuse()
1155
+ def get_storage() -> StorageContext:
1156
+ """Returns the :class:`~ray.train._internal.storage.StorageContext` storage
1157
+ context which gives advanced access to the filesystem and paths
1158
+ configured through `RunConfig`.
1159
+
1160
+ NOTE: This is a developer API, and the `StorageContext` interface may change
1161
+ without notice between minor versions.
1162
+ """
1163
+ return get_session().storage
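The coordination between the training thread and the main thread in this module (`report` puts a result on a bounded queue and then blocks on a semaphore; `get_next` consumes the result and releases the semaphore) can be sketched with plain `threading` primitives. This is an illustrative simplification, not Ray's implementation.

import queue
import threading

result_queue = queue.Queue(1)            # like _TrainSession.result_queue
continue_lock = threading.Semaphore(0)   # like _TrainSession.continue_lock


def training_func():
    for step in range(3):
        result_queue.put({"step": step})  # roughly what report() does
        continue_lock.acquire()           # block until the main thread resumes us


thread = threading.Thread(target=training_func, daemon=True)
thread.start()

while thread.is_alive() or not result_queue.empty():
    try:
        result = result_queue.get(timeout=0.1)  # roughly what get_next() does
    except queue.Empty:
        continue
    print("got", result)
    continue_lock.release()  # let the training thread run its next iteration

thread.join()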
.venv/lib/python3.11/site-packages/ray/train/_internal/state/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from ray.train._internal.state.state_manager import TrainRunStateManager
+
+ try:
+     import pydantic  # noqa: F401
+ except ImportError:
+     raise ModuleNotFoundError(
+         "pydantic isn't installed. "
+         "To install pydantic, please run 'pip install pydantic'"
+     )
+
+
+ __all__ = [
+     "TrainRunStateManager",
+ ]
.venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (581 Bytes).
 
.venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/schema.cpython-311.pyc ADDED
Binary file (8.55 kB).
 
.venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/state_actor.cpython-311.pyc ADDED
Binary file (3.44 kB).
 
.venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/state_manager.cpython-311.pyc ADDED
Binary file (6.71 kB).
 
.venv/lib/python3.11/site-packages/ray/train/_internal/state/schema.py ADDED
@@ -0,0 +1,158 @@
1
+ from enum import Enum
2
+ from typing import List, Optional
3
+
4
+ from ray._private.pydantic_compat import BaseModel, Field
5
+ from ray.dashboard.modules.job.pydantic_models import JobDetails
6
+ from ray.util.annotations import DeveloperAPI
7
+
8
+ MAX_ERROR_STACK_TRACE_LENGTH = 50000
9
+
10
+
11
+ @DeveloperAPI
12
+ class RunStatusEnum(str, Enum):
13
+ """Enumeration for the status of a train run."""
14
+
15
+ # (Deprecated) Replaced by RUNNING.
16
+ # The train run has started
17
+ STARTED = "STARTED"
18
+ # The train run is running
19
+ RUNNING = "RUNNING"
20
+ # The train run was terminated as expected
21
+ FINISHED = "FINISHED"
22
+ # The train run was terminated early due to errors in the training function
23
+ ERRORED = "ERRORED"
24
+ # The train run was terminated early due to system errors or controller errors
25
+ ABORTED = "ABORTED"
26
+
27
+
28
+ @DeveloperAPI
29
+ class ActorStatusEnum(str, Enum):
30
+ DEAD = "DEAD"
31
+ ALIVE = "ALIVE"
32
+
33
+
34
+ @DeveloperAPI
35
+ class TrainWorkerInfo(BaseModel):
36
+ """Metadata of a Ray Train worker."""
37
+
38
+ actor_id: str = Field(description="Actor ID of the worker.")
39
+ world_rank: int = Field(description="World rank of the worker.")
40
+ local_rank: int = Field(description="Local rank of the worker.")
41
+ node_rank: int = Field(description="Node rank of the worker.")
42
+ node_id: str = Field(description="ID of the node that the worker is running on.")
43
+ node_ip: str = Field(
44
+ description="IP address of the node that the worker is running on."
45
+ )
46
+ pid: int = Field(description="Process ID of the worker.")
47
+ gpu_ids: List[int] = Field(
48
+ description="A list of GPU ids allocated to that worker."
49
+ )
50
+ status: Optional[ActorStatusEnum] = Field(
51
+ description="The status of the train worker actor. It can be ALIVE or DEAD."
52
+ )
53
+
54
+
55
+ @DeveloperAPI
56
+ class MemoryInfo(BaseModel):
57
+ rss: int
58
+ vms: int
59
+ pfaults: Optional[int]
60
+ pageins: Optional[int]
61
+
62
+
63
+ @DeveloperAPI
64
+ class ProcessStats(BaseModel):
65
+ cpuPercent: float
66
+ # total memory, free memory, memory used ratio
67
+ mem: Optional[List[int]]
68
+ memoryInfo: MemoryInfo
69
+
70
+
71
+ class ProcessGPUUsage(BaseModel):
72
+ # GPU usage stats from a single process
73
+ pid: int
74
+ gpuMemoryUsage: int
75
+
76
+
77
+ @DeveloperAPI
78
+ class GPUStats(BaseModel):
79
+ uuid: str
80
+ index: int
81
+ name: str
82
+ utilizationGpu: Optional[float]
83
+ memoryUsed: float
84
+ memoryTotal: float
85
+ processInfo: ProcessGPUUsage
86
+
87
+
88
+ @DeveloperAPI
89
+ class TrainWorkerInfoWithDetails(TrainWorkerInfo):
90
+ """Metadata of a Ray Train worker."""
91
+
92
+ processStats: Optional[ProcessStats] = Field(
93
+ None, description="Process stats of the worker."
94
+ )
95
+ gpus: List[GPUStats] = Field(
96
+ default_factory=list,
97
+ description=(
98
+ "GPU stats of the worker. "
99
+ "Only returns GPUs that are attached to the worker process."
100
+ ),
101
+ )
102
+
103
+
104
+ @DeveloperAPI
105
+ class TrainDatasetInfo(BaseModel):
106
+ name: str = Field(
107
+ description="The key of the dataset dict specified in Ray Train Trainer."
108
+ )
109
+ dataset_uuid: str = Field(description="The uuid of the dataset.")
110
+ dataset_name: Optional[str] = Field(description="The name of the dataset.")
111
+
112
+
113
+ @DeveloperAPI
114
+ class TrainRunInfo(BaseModel):
115
+ """Metadata for a Ray Train run and information about its workers."""
116
+
117
+ name: str = Field(description="The name of the Train run.")
118
+ id: str = Field(description="The unique identifier for each Train run.")
119
+ job_id: str = Field(description="The Ray Job ID.")
120
+ controller_actor_id: str = Field(description="Actor Id of the Train controller.")
121
+ workers: List[TrainWorkerInfo] = Field(
122
+ description="A List of Train workers sorted by global ranks."
123
+ )
124
+ datasets: List[TrainDatasetInfo] = Field(
125
+ description="A List of dataset info for this Train run."
126
+ )
127
+ run_status: RunStatusEnum = Field(
128
+ description="The current status of the train run. It can be one of the "
129
+ "following: RUNNING, FINISHED, ERRORED, or ABORTED."
130
+ )
131
+ status_detail: str = Field(
132
+ description="Detailed information about the current run status, "
133
+ "such as error messages."
134
+ )
135
+ start_time_ms: int = Field(
136
+ description="The UNIX timestamp of the start time of this Train run."
137
+ )
138
+ end_time_ms: Optional[int] = Field(
139
+ description="The UNIX timestamp of the end time of this Train run. "
140
+ "If null, the Train run has not ended yet."
141
+ )
142
+
143
+
144
+ @DeveloperAPI
145
+ class TrainRunInfoWithDetails(TrainRunInfo):
146
+ """Metadata for a Ray Train run and information about its workers."""
147
+
148
+ workers: List[TrainWorkerInfoWithDetails] = Field(
149
+ description="A List of Train workers sorted by global ranks."
150
+ )
151
+ job_details: Optional[JobDetails] = Field(
152
+ None, description="Details of the job that started this Train run."
153
+ )
154
+
155
+
156
+ @DeveloperAPI
157
+ class TrainRunsResponse(BaseModel):
158
+ train_runs: List[TrainRunInfoWithDetails]
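To show how these nested models compose, here is a trimmed-down sketch using plain `pydantic` directly (the real schema goes through `ray._private.pydantic_compat` and carries many more fields); the field names are taken from the models above, but the classes here are illustrative stand-ins.

from typing import List, Optional

from pydantic import BaseModel


class WorkerInfo(BaseModel):
    actor_id: str
    world_rank: int
    gpu_ids: List[int] = []


class RunInfo(BaseModel):
    id: str
    name: str
    run_status: str
    workers: List[WorkerInfo]
    end_time_ms: Optional[int] = None


run = RunInfo(
    id="run-1",
    name="demo",
    run_status="RUNNING",
    workers=[WorkerInfo(actor_id="abc123", world_rank=0, gpu_ids=[0])],
)
print(run)  # a validated, serializable payload like the one the state actor stores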
.venv/lib/python3.11/site-packages/ray/train/_internal/state/state_actor.py ADDED
@@ -0,0 +1,62 @@
+ import logging
+ import threading
+ from typing import Dict, Optional
+
+ import ray
+ from ray.actor import ActorHandle
+ from ray.train._internal.state.schema import TrainRunInfo
+
+ logger = logging.getLogger(__name__)
+
+
+ @ray.remote(num_cpus=0)
+ class TrainStateActor:
+     def __init__(self):
+         self._run_infos: Dict[str, TrainRunInfo] = {}
+
+     def register_train_run(self, run_info: TrainRunInfo) -> None:
+         # Register a new train run.
+         self._run_infos[run_info.id] = run_info
+
+     def get_train_run(self, run_id: str) -> Optional[TrainRunInfo]:
+         # Retrieve a registered run by its id.
+         return self._run_infos.get(run_id, None)
+
+     def get_all_train_runs(self) -> Dict[str, TrainRunInfo]:
+         # Retrieve all registered train runs.
+         return self._run_infos
+
+
+ TRAIN_STATE_ACTOR_NAME = "train_state_actor"
+ TRAIN_STATE_ACTOR_NAMESPACE = "_train_state_actor"
+
+ _state_actor_lock: threading.RLock = threading.RLock()
+
+
+ def get_or_create_state_actor() -> ActorHandle:
+     """Get or create a `TrainStateActor` on the head node."""
+     with _state_actor_lock:
+         state_actor = TrainStateActor.options(
+             name=TRAIN_STATE_ACTOR_NAME,
+             namespace=TRAIN_STATE_ACTOR_NAMESPACE,
+             get_if_exists=True,
+             lifetime="detached",
+             resources={"node:__internal_head__": 0.001},
+             # Escape from the parent's placement group
+             scheduling_strategy="DEFAULT",
+         ).remote()
+
+         # Ensure the state actor is ready
+         ray.get(state_actor.__ray_ready__.remote())
+     return state_actor
+
+
+ def get_state_actor() -> Optional[ActorHandle]:
+     """Get the `TrainStateActor` if it exists, otherwise return None."""
+     try:
+         return ray.get_actor(
+             name=TRAIN_STATE_ACTOR_NAME,
+             namespace=TRAIN_STATE_ACTOR_NAMESPACE,
+         )
+     except ValueError:
+         return None
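The get-or-create logic above relies on a named, detached actor requested with `get_if_exists=True`, so every caller receives a handle to the same singleton. A minimal sketch of that pattern (assumes a running Ray installation; the actor and names below are illustrative, not Ray Train internals):

import ray


@ray.remote(num_cpus=0)
class Singleton:
    def ping(self) -> str:
        return "pong"


ray.init(ignore_reinit_error=True)

# Both handles resolve to the same detached actor because of get_if_exists=True.
a = Singleton.options(
    name="demo_singleton", namespace="demo", get_if_exists=True, lifetime="detached"
).remote()
b = Singleton.options(
    name="demo_singleton", namespace="demo", get_if_exists=True, lifetime="detached"
).remote()
print(ray.get([a.ping.remote(), b.ping.remote()]))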
.venv/lib/python3.11/site-packages/ray/train/_internal/state/state_manager.py ADDED
@@ -0,0 +1,126 @@
1
+ import logging
2
+ import os
3
+ from collections import defaultdict
4
+ from typing import Any, Dict
5
+
6
+ import ray
7
+ from ray.data import Dataset
8
+ from ray.train._internal.state.schema import (
9
+ RunStatusEnum,
10
+ TrainDatasetInfo,
11
+ TrainRunInfo,
12
+ TrainWorkerInfo,
13
+ )
14
+ from ray.train._internal.utils import check_for_failure
15
+ from ray.train._internal.worker_group import WorkerGroup
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class TrainRunStateManager:
21
+ """A class that aggregates and reports train run info to TrainStateActor.
22
+
23
+ This manager class is created on the train controller layer for each run.
24
+ """
25
+
26
+ def __init__(self, state_actor) -> None:
27
+ self.state_actor = state_actor
28
+ self.train_run_info_dict = defaultdict(dict)
29
+
30
+ def register_train_run(
31
+ self,
32
+ run_id: str,
33
+ job_id: str,
34
+ run_name: str,
35
+ run_status: str,
36
+ controller_actor_id: str,
37
+ datasets: Dict[str, Dataset],
38
+ worker_group: WorkerGroup,
39
+ start_time_ms: float,
40
+ status_detail: str = "",
41
+ ) -> None:
42
+ """Collect Train Run Info and report to StateActor."""
43
+
44
+ if not self.state_actor:
45
+ logger.warning(
46
+ "Unable to register train run since `TrainStateActor` is not started."
47
+ )
48
+ return
49
+
50
+ def collect_train_worker_info():
51
+ train_context = ray.train.get_context()
52
+ core_context = ray.runtime_context.get_runtime_context()
53
+
54
+ return TrainWorkerInfo(
55
+ world_rank=train_context.get_world_rank(),
56
+ local_rank=train_context.get_local_rank(),
57
+ node_rank=train_context.get_node_rank(),
58
+ actor_id=core_context.get_actor_id(),
59
+ node_id=core_context.get_node_id(),
60
+ node_ip=ray.util.get_node_ip_address(),
61
+ gpu_ids=ray.get_gpu_ids(),
62
+ pid=os.getpid(),
63
+ )
64
+
65
+ futures = [
66
+ worker_group.execute_single_async(index, collect_train_worker_info)
67
+ for index in range(len(worker_group))
68
+ ]
69
+ success, exception = check_for_failure(futures)
70
+
71
+ if not success:
72
+ logger.error(
73
+ "Failed to collect run information from the Ray Train "
74
+ f"workers:\n{exception}"
75
+ )
76
+ return
77
+
78
+ worker_info_list = ray.get(futures)
79
+ worker_info_list = sorted(worker_info_list, key=lambda info: info.world_rank)
80
+
81
+ dataset_info_list = [
82
+ TrainDatasetInfo(
83
+ name=ds_name,
84
+ dataset_name=ds._plan._dataset_name,
85
+ dataset_uuid=ds._plan._dataset_uuid,
86
+ )
87
+ for ds_name, ds in datasets.items()
88
+ ]
89
+
90
+ updates = dict(
91
+ id=run_id,
92
+ job_id=job_id,
93
+ name=run_name,
94
+ controller_actor_id=controller_actor_id,
95
+ workers=worker_info_list,
96
+ datasets=dataset_info_list,
97
+ start_time_ms=start_time_ms,
98
+ run_status=run_status,
99
+ status_detail=status_detail,
100
+ )
101
+
102
+ # Clear the cached info to avoid registering the same run twice
103
+ self.train_run_info_dict[run_id] = {}
104
+ self._update_train_run_info(run_id, updates)
105
+
106
+ def end_train_run(
107
+ self,
108
+ run_id: str,
109
+ run_status: RunStatusEnum,
110
+ status_detail: str,
111
+ end_time_ms: int,
112
+ ):
113
+ """Update the train run status when the training is finished."""
114
+ updates = dict(
115
+ run_status=run_status,
116
+ status_detail=status_detail,
117
+ end_time_ms=end_time_ms,
118
+ )
119
+ self._update_train_run_info(run_id, updates)
120
+
121
+ def _update_train_run_info(self, run_id: str, updates: Dict[str, Any]) -> None:
122
+ """Update specific fields of a registered TrainRunInfo instance."""
123
+ if run_id in self.train_run_info_dict:
124
+ self.train_run_info_dict[run_id].update(updates)
125
+ train_run_info = TrainRunInfo(**self.train_run_info_dict[run_id])
126
+ ray.get(self.state_actor.register_train_run.remote(train_run_info))
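`_update_train_run_info` follows an "accumulate partial updates, then re-validate the whole model" pattern: each call merges a dict of changed fields into the cached per-run dict and rebuilds the full `TrainRunInfo` before pushing it to the state actor. A standalone sketch of that pattern with a plain `pydantic` stand-in (the field set here is illustrative only):

from collections import defaultdict

from pydantic import BaseModel


class RunInfo(BaseModel):
    id: str
    run_status: str
    status_detail: str = ""


run_info_dict = defaultdict(dict)

# Registration populates the initial fields...
run_info_dict["run-1"].update(id="run-1", run_status="RUNNING")
print(RunInfo(**run_info_dict["run-1"]))

# ...and later updates only touch the fields that changed.
run_info_dict["run-1"].update(run_status="FINISHED")
print(RunInfo(**run_info_dict["run-1"]))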
.venv/lib/python3.11/site-packages/ray/train/_internal/storage.py ADDED
@@ -0,0 +1,725 @@
1
+ # Try import ray[train] core requirements (defined in setup.py)
2
+ # isort: off
3
+ try:
4
+ import fsspec # noqa
5
+ from fsspec.implementations.local import LocalFileSystem
6
+
7
+ except (ImportError, ModuleNotFoundError) as e:
8
+ raise RuntimeError(
9
+ "fsspec is a required dependency of Ray Train and Ray Tune. "
10
+ "Please install with: `pip install fsspec`"
11
+ ) from e
12
+
13
+ try:
14
+ import pyarrow
15
+ import pyarrow.fs
16
+
17
+ except (ImportError, ModuleNotFoundError) as e:
18
+ raise RuntimeError(
19
+ "pyarrow is a required dependency of Ray Train and Ray Tune. "
20
+ "Please install with: `pip install pyarrow`"
21
+ ) from e
22
+
23
+ try:
24
+ # check if Arrow has S3 support
25
+ from pyarrow.fs import S3FileSystem
26
+ except ImportError:
27
+ S3FileSystem = None
28
+ # isort: on
29
+
30
+ import fnmatch
31
+ import logging
32
+ import os
33
+ import shutil
34
+ from pathlib import Path
35
+ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union
36
+
37
+ from ray.air._internal.filelock import TempFileLock
38
+ from ray.train._internal.syncer import SyncConfig, Syncer, _BackgroundSyncer
39
+ from ray.train.constants import _get_ray_train_session_dir
40
+ from ray.util.annotations import DeveloperAPI
41
+
42
+ if TYPE_CHECKING:
43
+ from ray.train._checkpoint import Checkpoint
44
+
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+
49
+ _VALIDATE_STORAGE_MARKER_FILENAME = ".validate_storage_marker"
50
+
51
+
52
+ class _ExcludingLocalFilesystem(LocalFileSystem):
53
+ """LocalFileSystem wrapper to exclude files according to patterns.
54
+
55
+ Args:
56
+ root_path: Root path to strip when matching with the exclude pattern.
57
+ Ex: root_path="/tmp/a/b/c", exclude=["*a*"], will exclude
58
+ /tmp/a/b/c/_a_.txt but not ALL of /tmp/a/*.
59
+ exclude: List of patterns that are applied to files returned by
60
+ ``self.find()``. If a file path matches this pattern, it will
61
+ be excluded.
62
+
63
+ """
64
+
65
+ def __init__(self, root_path: Path, exclude: List[str], **kwargs):
66
+ super().__init__(**kwargs)
67
+ self._exclude = exclude
68
+ self._root_path = root_path
69
+
70
+ @property
71
+ def fsid(self):
72
+ return "_excluding_local"
73
+
74
+ def _should_exclude(self, path: str) -> bool:
75
+ """Return True if `path` (relative to `root_path`) matches any of the
76
+ `self._exclude` patterns."""
77
+ path = Path(path)
78
+ relative_path = path.relative_to(self._root_path).as_posix()
79
+ match_candidates = [relative_path]
80
+ if path.is_dir():
81
+ # Everything is in posix path format ('/')
82
+ match_candidates.append(relative_path + "/")
83
+
84
+ for excl in self._exclude:
85
+ if any(fnmatch.fnmatch(candidate, excl) for candidate in match_candidates):
86
+ return True
87
+ return False
88
+
89
+ def find(self, path, maxdepth=None, withdirs=False, detail=False, **kwargs):
90
+ """Call parent find() and exclude from result."""
91
+ paths = super().find(
92
+ path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, **kwargs
93
+ )
94
+ if detail:
95
+ return {
96
+ path: out
97
+ for path, out in paths.items()
98
+ if not self._should_exclude(path)
99
+ }
100
+ else:
101
+ return [path for path in paths if not self._should_exclude(path)]
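The exclusion check above reduces to fnmatch-matching patterns against paths made relative to a root directory. Below is a standalone sketch of that matching logic, without the fsspec wrapper; the paths and pattern are hypothetical and mirror the docstring example, and the sketch always tries the trailing-slash form that the class only adds for directories.

```python
import fnmatch
from pathlib import PurePosixPath


def should_exclude(path: str, root: str, exclude: list) -> bool:
    # Match patterns against the path relative to the root, in posix form.
    rel = PurePosixPath(path).relative_to(root).as_posix()
    # Simplification: also try the trailing-slash candidate for every path.
    candidates = [rel, rel + "/"]
    return any(
        fnmatch.fnmatch(c, pattern) for pattern in exclude for c in candidates
    )


# Hypothetical paths, mirroring the docstring example above.
assert should_exclude("/tmp/a/b/c/_a_.txt", "/tmp/a/b/c", ["*a*"])
assert not should_exclude("/tmp/a/b/c/img.png", "/tmp/a/b/c", ["*a*"])
```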
102
+
103
+
104
+ def _pyarrow_fs_copy_files(
105
+ source, destination, source_filesystem=None, destination_filesystem=None, **kwargs
106
+ ):
107
+ if S3FileSystem and isinstance(destination_filesystem, pyarrow.fs.S3FileSystem):
108
+ # Workaround multi-threading issue with pyarrow. Note that use_threads=True
109
+ # is safe for download, just not for uploads, see:
110
+ # https://github.com/apache/arrow/issues/32372
111
+ kwargs.setdefault("use_threads", False)
112
+
113
+ # Use a large chunk size to speed up large checkpoint transfers.
114
+ kwargs.setdefault("chunk_size", 64 * 1024 * 1024)
115
+
116
+ return pyarrow.fs.copy_files(
117
+ source,
118
+ destination,
119
+ source_filesystem=source_filesystem,
120
+ destination_filesystem=destination_filesystem,
121
+ **kwargs,
122
+ )
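For reference, here is a minimal sketch of calling `pyarrow.fs.copy_files` directly with the same S3 workaround. The bucket, key, and local directory are hypothetical, and valid AWS credentials are assumed to be configured.

```python
import pyarrow.fs

local_dir = "/tmp/my_checkpoint"                  # hypothetical source directory
s3 = pyarrow.fs.S3FileSystem(region="us-west-2")  # assumes valid AWS credentials

pyarrow.fs.copy_files(
    local_dir,
    "my-bucket/experiments/checkpoint_000000",    # hypothetical destination
    destination_filesystem=s3,
    use_threads=False,            # single-threaded uploads avoid the pyarrow issue noted above
    chunk_size=64 * 1024 * 1024,  # larger chunks speed up big checkpoint transfers
)
```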
123
+
124
+
125
+ # TODO(justinvyu): Add unit tests for all these utils.
126
+
127
+
128
+ def _delete_fs_path(fs: pyarrow.fs.FileSystem, fs_path: str):
129
+ is_dir = _is_directory(fs, fs_path)
130
+
131
+ try:
132
+ if is_dir:
133
+ fs.delete_dir(fs_path)
134
+ else:
135
+ fs.delete_file(fs_path)
136
+ except Exception:
137
+ logger.exception(f"Caught exception when deleting path at ({fs}, {fs_path}):")
138
+
139
+
140
+ def _download_from_fs_path(
141
+ fs: pyarrow.fs.FileSystem,
142
+ fs_path: str,
143
+ local_path: str,
144
+ filelock: bool = True,
145
+ ):
146
+ """Downloads a directory or file from (fs, fs_path) to a local path.
147
+
148
+ If fs_path points to a directory:
149
+ - The full directory contents are downloaded directly into `local_path`,
150
+ rather than to a subdirectory of `local_path`.
151
+
152
+ If fs_path points to a file:
153
+ - The file is downloaded to `local_path`, which is expected to be a file path.
154
+
155
+ If the download fails and the directory did not previously exist,
156
+ the `local_path` contents are cleaned up before raising.
157
+
158
+ NOTE: This method creates `local_path`'s parent directories if they do not
159
+ already exist. If the download fails, this does NOT clean up all the parent
160
+ directories that were created.
161
+
162
+ Args:
163
+ fs: The filesystem to download from.
164
+ fs_path: The filesystem path (either a directory or a file) to download.
165
+ local_path: The local path to download to.
166
+ filelock: Whether to require a file lock before downloading, useful for
167
+ multiple downloads to the same directory that may be happening in parallel.
168
+
169
+ Raises:
170
+ FileNotFoundError: if (fs, fs_path) doesn't exist.
171
+ """
172
+
173
+ _local_path = Path(local_path).resolve()
174
+ exists_before = _local_path.exists()
175
+ if _is_directory(fs=fs, fs_path=fs_path):
176
+ _local_path.mkdir(parents=True, exist_ok=True)
177
+ else:
178
+ _local_path.parent.mkdir(parents=True, exist_ok=True)
179
+
180
+ try:
181
+ if filelock:
182
+ with TempFileLock(f"{os.path.normpath(local_path)}.lock"):
183
+ _pyarrow_fs_copy_files(fs_path, local_path, source_filesystem=fs)
184
+ else:
185
+ _pyarrow_fs_copy_files(fs_path, local_path, source_filesystem=fs)
186
+ except Exception as e:
187
+ # Clean up the directory if downloading was unsuccessful
188
+ if not exists_before:
189
+ shutil.rmtree(local_path, ignore_errors=True)
190
+ raise e
191
+
192
+
193
+ def _upload_to_fs_path(
194
+ local_path: str,
195
+ fs: pyarrow.fs.FileSystem,
196
+ fs_path: str,
197
+ exclude: Optional[List[str]] = None,
198
+ ) -> None:
199
+ """Uploads a local directory or file to (fs, fs_path).
200
+
201
+ NOTE: This will create all necessary parent directories at the destination.
202
+
203
+ Args:
204
+ local_path: The local path to upload.
205
+ fs: The filesystem to upload to.
206
+ fs_path: The filesystem path where the dir/file will be uploaded to.
207
+ exclude: A list of filename matches to exclude from upload. This includes
208
+ all files under subdirectories as well.
209
+ This pattern will match with the relative paths of all files under
210
+ `local_path`.
211
+ Ex: ["*.png"] to exclude all .png images.
212
+ """
213
+
214
+ if not exclude:
215
+ # TODO(justinvyu): uploading a single file doesn't work
216
+ # (since we always create a directory at fs_path)
217
+ _create_directory(fs=fs, fs_path=fs_path)
218
+ _pyarrow_fs_copy_files(local_path, fs_path, destination_filesystem=fs)
219
+ return
220
+
221
+ _upload_to_uri_with_exclude_fsspec(
222
+ local_path=local_path, fs=fs, fs_path=fs_path, exclude=exclude
223
+ )
224
+
225
+
226
+ def _upload_to_uri_with_exclude_fsspec(
227
+ local_path: str, fs: "pyarrow.fs", fs_path: str, exclude: Optional[List[str]]
228
+ ) -> None:
229
+ local_fs = _ExcludingLocalFilesystem(root_path=local_path, exclude=exclude)
230
+ handler = pyarrow.fs.FSSpecHandler(local_fs)
231
+ source_fs = pyarrow.fs.PyFileSystem(handler)
232
+
233
+ _create_directory(fs=fs, fs_path=fs_path)
234
+ _pyarrow_fs_copy_files(
235
+ local_path, fs_path, source_filesystem=source_fs, destination_filesystem=fs
236
+ )
237
+
238
+
239
+ def _list_at_fs_path(
240
+ fs: pyarrow.fs.FileSystem,
241
+ fs_path: str,
242
+ file_filter: Optional[Callable[[pyarrow.fs.FileInfo], bool]] = None,
243
+ ) -> List[str]:
244
+ """Returns the list of filenames at (fs, fs_path), similar to os.listdir.
245
+
246
+ If the path doesn't exist, returns an empty list.
247
+ """
248
+ if file_filter is None:
249
+ file_filter = lambda x: True # noqa: E731
250
+
251
+ selector = pyarrow.fs.FileSelector(fs_path, allow_not_found=True, recursive=False)
252
+ return [
253
+ os.path.relpath(file_info.path.lstrip("/"), start=fs_path.lstrip("/"))
254
+ for file_info in fs.get_file_info(selector)
255
+ if file_filter(file_info)
256
+ ]
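A minimal sketch of the same non-recursive listing pattern against a local directory; the path is hypothetical.

```python
import pyarrow.fs

fs = pyarrow.fs.LocalFileSystem()
selector = pyarrow.fs.FileSelector(
    "/tmp/ray_results/exp_name",  # hypothetical directory
    allow_not_found=True,         # yields an empty listing instead of raising
    recursive=False,
)
# FileInfo objects for the immediate children of the directory.
for info in fs.get_file_info(selector):
    print(info.path, info.type)
```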
257
+
258
+
259
+ def _exists_at_fs_path(fs: pyarrow.fs.FileSystem, fs_path: str) -> bool:
260
+ """Returns True if (fs, fs_path) exists."""
261
+
262
+ valid = fs.get_file_info(fs_path)
263
+ return valid.type != pyarrow.fs.FileType.NotFound
264
+
265
+
266
+ def _is_directory(fs: pyarrow.fs.FileSystem, fs_path: str) -> bool:
267
+ """Checks if (fs, fs_path) is a directory or a file.
268
+
269
+ Raises:
270
+ FileNotFoundError: if (fs, fs_path) doesn't exist.
271
+ """
272
+
273
+ file_info = fs.get_file_info(fs_path)
274
+ if file_info.type == pyarrow.fs.FileType.NotFound:
275
+ raise FileNotFoundError(f"Path not found: ({fs}, {fs_path})")
276
+
277
+ return not file_info.is_file
278
+
279
+
280
+ def _create_directory(fs: pyarrow.fs.FileSystem, fs_path: str) -> None:
281
+ """Create directory at (fs, fs_path).
282
+
283
+ Some external filesystems require directories to already exist, or at least
284
+ the `netloc` to be created (e.g. PyArrow's ``mock://`` filesystem).
285
+
286
+ Generally this should be done before and outside of Ray applications. This
287
+ utility is thus primarily used in testing, e.g. of ``mock://`` URIs.
288
+ """
289
+ try:
290
+ fs.create_dir(fs_path)
291
+ except Exception:
292
+ logger.exception(
293
+ f"Caught exception when creating directory at ({fs}, {fs_path}):"
294
+ )
295
+
296
+
297
+ def get_fs_and_path(
298
+ storage_path: Union[str, os.PathLike],
299
+ storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
300
+ ) -> Tuple[pyarrow.fs.FileSystem, str]:
301
+ """Returns the fs and path from a storage path and an optional custom fs.
302
+
303
+ Args:
304
+ storage_path: A storage path or URI. (ex: s3://bucket/path or /tmp/ray_results)
305
+ storage_filesystem: A custom filesystem to use. If not provided,
306
+ this will be auto-resolved by pyarrow. If provided, the storage_path
307
+ is assumed to be prefix-stripped already, and must be a valid path
308
+ on the filesystem.
309
+ """
310
+ storage_path = str(storage_path)
311
+
312
+ if storage_filesystem:
313
+ return storage_filesystem, storage_path
314
+
315
+ return pyarrow.fs.FileSystem.from_uri(storage_path)
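A short sketch of how `pyarrow.fs.FileSystem.from_uri` splits a storage path into a filesystem plus a prefix-stripped path, which is the pair this helper returns when no custom filesystem is passed. The S3 bucket is hypothetical and resolving it requires working credentials.

```python
import pyarrow.fs

# A plain local path resolves to LocalFileSystem plus the path itself.
fs, path = pyarrow.fs.FileSystem.from_uri("/tmp/ray_results")
print(type(fs).__name__, path)   # LocalFileSystem /tmp/ray_results

# A cloud URI resolves to the matching filesystem with the scheme stripped.
fs, path = pyarrow.fs.FileSystem.from_uri("s3://my-bucket/ray_results")
print(type(fs).__name__, path)   # S3FileSystem my-bucket/ray_results
```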
316
+
317
+
318
+ class _FilesystemSyncer(_BackgroundSyncer):
319
+ """Syncer between local filesystem and a `storage_filesystem`."""
320
+
321
+ def __init__(self, storage_filesystem: Optional["pyarrow.fs.FileSystem"], **kwargs):
322
+ self.storage_filesystem = storage_filesystem
323
+ super().__init__(**kwargs)
324
+
325
+ def _sync_up_command(
326
+ self, local_path: str, uri: str, exclude: Optional[List] = None
327
+ ) -> Tuple[Callable, Dict]:
328
+ # TODO(justinvyu): Defer this cleanup up as part of the
329
+ # external-facing Syncer deprecation.
330
+ fs_path = uri
331
+ return (
332
+ _upload_to_fs_path,
333
+ dict(
334
+ local_path=local_path,
335
+ fs=self.storage_filesystem,
336
+ fs_path=fs_path,
337
+ exclude=exclude,
338
+ ),
339
+ )
340
+
341
+ def _sync_down_command(self, uri: str, local_path: str) -> Tuple[Callable, Dict]:
342
+ fs_path = uri
343
+ return (
344
+ _download_from_fs_path,
345
+ dict(
346
+ fs=self.storage_filesystem,
347
+ fs_path=fs_path,
348
+ local_path=local_path,
349
+ ),
350
+ )
351
+
352
+ def _delete_command(self, uri: str) -> Tuple[Callable, Dict]:
353
+ fs_path = uri
354
+ return _delete_fs_path, dict(fs=self.storage_filesystem, fs_path=fs_path)
355
+
356
+
357
+ @DeveloperAPI
358
+ class StorageContext:
359
+ """Shared context that holds the source of truth for all paths and
360
+ storage utilities, passed along from the driver to workers.
361
+
362
+ This object defines a few types of paths:
363
+ 1. *_fs_path: A path on the `storage_filesystem`. This is a regular path
364
+ which has been prefix-stripped by pyarrow.fs.FileSystem.from_uri and
365
+ can be joined with `Path(...).as_posix()`.
366
+ 2. *_driver_staging_path: The temporary staging directory on the local filesystem
367
+ where driver artifacts are saved to before persisting them to storage.
368
+ 3. trial_working_directory: The local filesystem path that the remote
369
+ actors' working directories are moved to by default.
370
+ This is separated from the driver staging path so that driver syncing
371
+ does not implicitly upload the trial working directory, for trials on the
372
+ driver node.
373
+
374
+ Example with storage_path="mock:///bucket/path?param=1":
375
+
376
+ >>> import ray
377
+ >>> from ray.train._internal.storage import StorageContext
378
+ >>> import os
379
+ >>> _ = ray.init()
380
+ >>> storage = StorageContext(
381
+ ... storage_path="mock://netloc/bucket/path?param=1",
382
+ ... experiment_dir_name="exp_name",
383
+ ... )
384
+ >>> storage.storage_filesystem # Auto-resolved # doctest: +ELLIPSIS
385
+ <pyarrow._fs._MockFileSystem object...
386
+ >>> storage.experiment_fs_path
387
+ 'bucket/path/exp_name'
388
+ >>> storage.experiment_driver_staging_path # doctest: +ELLIPSIS
389
+ '/tmp/ray/session_.../artifacts/.../exp_name/driver_artifacts'
390
+ >>> storage.trial_dir_name = "trial_dir"
391
+ >>> storage.trial_fs_path
392
+ 'bucket/path/exp_name/trial_dir'
393
+ >>> storage.trial_driver_staging_path # doctest: +ELLIPSIS
394
+ '/tmp/ray/session_.../artifacts/.../exp_name/driver_artifacts/trial_dir'
395
+ >>> storage.trial_working_directory # doctest: +ELLIPSIS
396
+ '/tmp/ray/session_.../artifacts/.../exp_name/working_dirs/trial_dir'
397
+ >>> storage.current_checkpoint_index = 1
398
+ >>> storage.checkpoint_fs_path
399
+ 'bucket/path/exp_name/trial_dir/checkpoint_000001'
400
+ >>> ray.shutdown()
401
+
402
+ Example with storage_path="/tmp/ray_results":
403
+
404
+ >>> from ray.train._internal.storage import StorageContext
405
+ >>> storage = StorageContext(
406
+ ... storage_path="/tmp/ray_results",
407
+ ... experiment_dir_name="exp_name",
408
+ ... )
409
+ >>> storage.storage_fs_path
410
+ '/tmp/ray_results'
411
+ >>> storage.experiment_fs_path
412
+ '/tmp/ray_results/exp_name'
413
+ >>> storage.storage_filesystem # Auto-resolved # doctest: +ELLIPSIS
414
+ <pyarrow._fs.LocalFileSystem object...
415
+
416
+ Internal Usage Examples:
417
+ - To copy files to the trial directory on the storage filesystem:
418
+
419
+ pyarrow.fs.copy_files(
420
+ local_dir,
421
+ Path(storage.trial_fs_path, "subdir").as_posix(),
422
+ destination_filesystem=storage.filesystem
423
+ )
424
+
425
+ .. warning::
426
+ This is an experimental developer API and is subject to change
427
+ without notice between versions.
428
+ """
429
+
430
+ def __init__(
431
+ self,
432
+ storage_path: Union[str, os.PathLike],
433
+ experiment_dir_name: str,
434
+ sync_config: Optional[SyncConfig] = None,
435
+ storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
436
+ trial_dir_name: Optional[str] = None,
437
+ current_checkpoint_index: int = -1,
438
+ ):
439
+ from ray.tune.utils import date_str
440
+
441
+ self.custom_fs_provided = storage_filesystem is not None
442
+
443
+ # Invariant: (`storage_filesystem`, `storage_path`) is the location where
444
+ # *all* results can be accessed.
445
+ self.experiment_dir_name = experiment_dir_name
446
+ self.trial_dir_name = trial_dir_name
447
+ self.current_checkpoint_index = current_checkpoint_index
448
+ self.sync_config = sync_config or SyncConfig()
449
+
450
+ self.storage_filesystem, self.storage_fs_path = get_fs_and_path(
451
+ storage_path, storage_filesystem
452
+ )
453
+ self.storage_fs_path = Path(self.storage_fs_path).as_posix()
454
+
455
+ self.syncer: Syncer = _FilesystemSyncer(
456
+ storage_filesystem=self.storage_filesystem,
457
+ sync_period=self.sync_config.sync_period,
458
+ sync_timeout=self.sync_config.sync_timeout,
459
+ )
460
+
461
+ self._create_validation_file()
462
+ self._check_validation_file()
463
+
464
+ # Timestamp is used to create a unique session directory for the current
465
+ # training job. This is used to avoid conflicts when multiple training jobs
466
+ # run with the same name in the same cluster.
467
+ # This is set ONCE at the creation of the storage context, on the driver.
468
+ self._timestamp = date_str()
469
+
470
+ def __str__(self):
471
+ return (
472
+ "StorageContext<\n"
473
+ f" storage_filesystem='{self.storage_filesystem.type_name}',\n"
474
+ f" storage_fs_path='{self.storage_fs_path}',\n"
475
+ f" experiment_dir_name='{self.experiment_dir_name}',\n"
476
+ f" trial_dir_name='{self.trial_dir_name}',\n"
477
+ f" current_checkpoint_index={self.current_checkpoint_index},\n"
478
+ ">"
479
+ )
480
+
481
+ def _create_validation_file(self):
482
+ """On the creation of a storage context, create a validation file at the
483
+ storage path to verify that the storage path can be written to.
484
+ This validation file is also used to check whether the storage path is
485
+ accessible by all nodes in the cluster."""
486
+ valid_file = Path(
487
+ self.experiment_fs_path, _VALIDATE_STORAGE_MARKER_FILENAME
488
+ ).as_posix()
489
+ self.storage_filesystem.create_dir(self.experiment_fs_path)
490
+ with self.storage_filesystem.open_output_stream(valid_file):
491
+ pass
492
+
493
+ def _check_validation_file(self):
494
+ """Checks that the validation file exists at the storage path."""
495
+ valid_file = Path(
496
+ self.experiment_fs_path, _VALIDATE_STORAGE_MARKER_FILENAME
497
+ ).as_posix()
498
+ if not _exists_at_fs_path(fs=self.storage_filesystem, fs_path=valid_file):
499
+ raise RuntimeError(
500
+ f"Unable to set up cluster storage with the following settings:\n{self}"
501
+ "\nCheck that all nodes in the cluster have read/write access "
502
+ "to the configured storage path. `RunConfig(storage_path)` should be "
503
+ "set to a cloud storage URI or a shared filesystem path accessible "
504
+ "by all nodes in your cluster ('s3://bucket' or '/mnt/nfs'). "
505
+ "A local path on the head node is not accessible by worker nodes. "
506
+ "See: https://docs.ray.io/en/latest/train/user-guides/persistent-storage.html" # noqa: E501
507
+ )
508
+
509
+ def _update_checkpoint_index(self, metrics: Dict):
510
+ # By default, increase by 1. This can be overridden to customize checkpoint
511
+ # directories.
512
+ self.current_checkpoint_index += 1
513
+
514
+ def persist_current_checkpoint(self, checkpoint: "Checkpoint") -> "Checkpoint":
515
+ """Persists a given checkpoint to the current checkpoint path on the filesystem.
516
+
517
+ "Current" is defined by the `current_checkpoint_index` attribute of the
518
+ storage context.
519
+
520
+ This method copies the checkpoint files to the storage location.
521
+ It's up to the user to delete the original checkpoint files if desired.
522
+
523
+ For example, the original directory is typically a local temp directory.
524
+
525
+ Args:
526
+ checkpoint: The checkpoint to persist to (fs, checkpoint_fs_path).
527
+
528
+ Returns:
529
+ Checkpoint: A Checkpoint pointing to the persisted checkpoint location.
530
+ """
531
+ # TODO(justinvyu): Fix this cyclical import.
532
+ from ray.train._checkpoint import Checkpoint
533
+
534
+ logger.debug(
535
+ "Copying checkpoint files to storage path:\n"
536
+ "({source_fs}, {source}) -> ({dest_fs}, {destination})".format(
537
+ source=checkpoint.path,
538
+ destination=self.checkpoint_fs_path,
539
+ source_fs=checkpoint.filesystem,
540
+ dest_fs=self.storage_filesystem,
541
+ )
542
+ )
543
+
544
+ # Raise an error if the storage path is not accessible when
545
+ # attempting to upload a checkpoint from a remote worker.
546
+ # Ex: If storage_path is a local path, then a validation marker
547
+ # will only exist on the head node but not the worker nodes.
548
+ self._check_validation_file()
549
+
550
+ self.storage_filesystem.create_dir(self.checkpoint_fs_path)
551
+ _pyarrow_fs_copy_files(
552
+ source=checkpoint.path,
553
+ destination=self.checkpoint_fs_path,
554
+ source_filesystem=checkpoint.filesystem,
555
+ destination_filesystem=self.storage_filesystem,
556
+ )
557
+
558
+ persisted_checkpoint = Checkpoint(
559
+ filesystem=self.storage_filesystem,
560
+ path=self.checkpoint_fs_path,
561
+ )
562
+ logger.info(f"Checkpoint successfully created at: {persisted_checkpoint}")
563
+ return persisted_checkpoint
564
+
565
+ def persist_artifacts(self, force: bool = False) -> None:
566
+ """Persists all artifacts within `trial_working_directory` to storage.
567
+
568
+ This method possibly launches a background task to sync the trial dir,
569
+ depending on the `sync_period` + `sync_artifacts_on_checkpoint`
570
+ settings of `SyncConfig`.
571
+
572
+ `(local_fs, trial_working_dir) -> (storage_filesystem, trial_fs_path)`
573
+
574
+ Args:
575
+ force: If True, wait for a previous sync to finish, launch a new one,
576
+ and wait for that one to finish. By the end of a `force=True` call, the
577
+ latest version of the trial artifacts will be persisted.
578
+ """
579
+ if not self.sync_config.sync_artifacts:
580
+ return
581
+
582
+ # Skip if there are no artifacts to sync
583
+ is_empty = not any(os.scandir(self.trial_working_directory))
584
+ if is_empty:
585
+ return
586
+
587
+ if force:
588
+ self.syncer.wait()
589
+ self.syncer.sync_up(
590
+ local_dir=self.trial_working_directory, remote_dir=self.trial_fs_path
591
+ )
592
+ self.syncer.wait()
593
+ else:
594
+ self.syncer.sync_up_if_needed(
595
+ local_dir=self.trial_working_directory, remote_dir=self.trial_fs_path
596
+ )
597
+
598
+ @property
599
+ def experiment_fs_path(self) -> str:
600
+ """The path on the `storage_filesystem` to the experiment directory.
601
+
602
+ NOTE: This does not have a URI prefix anymore, since it has been stripped
603
+ by pyarrow.fs.FileSystem.from_uri already. The URI scheme information is
604
+ kept in `storage_filesystem` instead.
605
+ """
606
+ return Path(self.storage_fs_path, self.experiment_dir_name).as_posix()
607
+
608
+ def _get_session_path(self) -> str:
609
+ """The Ray Train/Tune session local directory used to stage files
610
+ before persisting to the storage filesystem."""
611
+ return Path(
612
+ _get_ray_train_session_dir(), self._timestamp, self.experiment_dir_name
613
+ ).as_posix()
614
+
615
+ @property
616
+ def experiment_driver_staging_path(self) -> str:
617
+ """The local filesystem path of the experiment directory on the driver node.
618
+
619
+ The driver is the node where `Trainer.fit`/`Tuner.fit` is being called.
620
+
621
+ This path is of the form:
622
+ `/tmp/ray/session_<session_id>/artifacts/<ray-train-job-timestamp>/
623
+ <experiment_dir_name>/driver_artifacts`
624
+
625
+ This should be used as the temporary staging location for files *on the driver*
626
+ before syncing them to `experiment_fs_path`.
627
+ For example, the search algorithm should dump its state to this directory.
628
+ See `trial_driver_staging_path` for writing trial-specific artifacts.
629
+
630
+ The directory is synced to
631
+ `{storage_path}/{experiment_dir_name}` periodically.
632
+ See `_ExperimentCheckpointManager.checkpoint` for where that happens.
633
+ """
634
+ return Path(self._get_session_path(), "driver_artifacts").as_posix()
635
+
636
+ @property
637
+ def trial_fs_path(self) -> str:
638
+ """The trial directory path on the `storage_filesystem`.
639
+
640
+ Raises a RuntimeError if `trial_dir_name` is not set beforehand.
641
+ """
642
+ if self.trial_dir_name is None:
643
+ raise RuntimeError(
644
+ "Should not access `trial_fs_path` without setting `trial_dir_name`"
645
+ )
646
+ return Path(self.experiment_fs_path, self.trial_dir_name).as_posix()
647
+
648
+ @property
649
+ def trial_driver_staging_path(self) -> str:
650
+ """The local filesystem path of the trial directory on the driver.
651
+
652
+ The driver is the node where `Trainer.fit`/`Tuner.fit` is being called.
653
+
654
+ This path is of the form:
655
+ `/tmp/ray/session_<session_id>/artifacts/<ray-train-job-timestamp>/
656
+ <experiment_dir_name>/driver_artifacts/<trial_dir_name>`
657
+
658
+ This should be used as the temporary location for files on the driver
659
+ before persisting them to `trial_fs_path`.
660
+
661
+ For example, callbacks (e.g., JsonLoggerCallback) should write trial-specific
662
+ logfiles within this directory.
663
+ """
664
+ if self.trial_dir_name is None:
665
+ raise RuntimeError(
666
+ "Should not access `trial_driver_staging_path` "
667
+ "without setting `trial_dir_name`"
668
+ )
669
+ return Path(self.experiment_driver_staging_path, self.trial_dir_name).as_posix()
670
+
671
+ @property
672
+ def trial_working_directory(self) -> str:
673
+ """The local filesystem path to trial working directory.
674
+
675
+ This path is of the form:
676
+ `/tmp/ray/session_<session_id>/artifacts/<ray-train-job-timestamp>/
677
+ <experiment_dir_name>/working_dirs/<trial_dir_name>`
678
+
679
+ Ray Train/Tune moves the remote actor's working directory to this path
680
+ by default, unless disabled by `RAY_CHDIR_TO_TRIAL_DIR` environment variable.
681
+
682
+ Writing files to this directory allows users to persist training artifacts
683
+ if `SyncConfig(sync_artifacts=True)` is set.
684
+ """
685
+ if self.trial_dir_name is None:
686
+ raise RuntimeError(
687
+ "Cannot access `trial_working_directory` without "
688
+ "setting `trial_dir_name`"
689
+ )
690
+ return Path(
691
+ self._get_session_path(), "working_dirs", self.trial_dir_name
692
+ ).as_posix()
693
+
694
+ @property
695
+ def checkpoint_fs_path(self) -> str:
696
+ """The current checkpoint directory path on the `storage_filesystem`.
697
+
698
+ "Current" refers to the checkpoint that is currently being created/persisted.
699
+ The user of this class is responsible for setting the `current_checkpoint_index`
700
+ (e.g., incrementing when needed).
701
+ """
702
+ return Path(self.trial_fs_path, self.checkpoint_dir_name).as_posix()
703
+
704
+ @property
705
+ def checkpoint_dir_name(self) -> str:
706
+ """The current checkpoint directory name, based on the checkpoint index."""
707
+ return StorageContext._make_checkpoint_dir_name(self.current_checkpoint_index)
708
+
709
+ @staticmethod
710
+ def get_experiment_dir_name(run_obj: Union[str, Callable, Type]) -> str:
711
+ from ray.tune.experiment import Experiment
712
+ from ray.tune.utils import date_str
713
+
714
+ run_identifier = Experiment.get_trainable_name(run_obj)
715
+
716
+ if bool(int(os.environ.get("TUNE_DISABLE_DATED_SUBDIR", 0))):
717
+ dir_name = run_identifier
718
+ else:
719
+ dir_name = "{}_{}".format(run_identifier, date_str())
720
+ return dir_name
721
+
722
+ @staticmethod
723
+ def _make_checkpoint_dir_name(index: int):
724
+ """Get the name of the checkpoint directory, given an index."""
725
+ return f"checkpoint_{index:06d}"
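For illustration, the checkpoint directory naming and posix-path joining used by the properties above; the trial path and index are hypothetical.

```python
from pathlib import Path

index = 3
checkpoint_dir_name = f"checkpoint_{index:06d}"      # 'checkpoint_000003'
trial_fs_path = "bucket/path/exp_name/trial_dir"     # hypothetical trial path
checkpoint_fs_path = Path(trial_fs_path, checkpoint_dir_name).as_posix()
print(checkpoint_fs_path)  # bucket/path/exp_name/trial_dir/checkpoint_000003
```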
.venv/lib/python3.11/site-packages/ray/train/_internal/syncer.py ADDED
@@ -0,0 +1,490 @@
1
+ import abc
2
+ import logging
3
+ import threading
4
+ import time
5
+ import traceback
6
+ from dataclasses import dataclass
7
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
8
+
9
+ from ray._private.thirdparty.tabulate.tabulate import tabulate
10
+ from ray.train.constants import _DEPRECATED_VALUE
11
+ from ray.util.annotations import DeveloperAPI, PublicAPI
12
+ from ray.widgets import Template
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Syncing period for syncing checkpoints between nodes or to cloud.
17
+ DEFAULT_SYNC_PERIOD = 300
18
+
19
+ # Default sync timeout after which syncing processes are aborted
20
+ DEFAULT_SYNC_TIMEOUT = 1800
21
+
22
+
23
+ @PublicAPI(stability="stable")
24
+ @dataclass
25
+ class SyncConfig:
26
+ """Configuration object for Train/Tune file syncing to `RunConfig(storage_path)`.
27
+
28
+ In Ray Train/Tune, here is where syncing (mainly uploading) happens:
29
+
30
+ The experiment driver (on the head node) syncs the experiment directory to storage
31
+ (which includes experiment state such as searcher state, the list of trials
32
+ and their statuses, and trial metadata).
33
+
34
+ It's also possible to sync artifacts from the trial directory to storage
35
+ by setting `sync_artifacts=True`.
36
+ For a Ray Tune run with many trials, each trial will upload its trial directory
37
+ to storage, which includes arbitrary files that you dumped during the run.
38
+ For a Ray Train run doing distributed training, each remote worker will similarly
39
+ upload its trial directory to storage.
40
+
41
+ See :ref:`persistent-storage-guide` for more details and examples.
42
+
43
+ Args:
44
+ sync_period: Minimum time in seconds to wait between two sync operations.
45
+ A smaller ``sync_period`` will have the data in storage updated more often
46
+ but introduces more syncing overhead. Defaults to 5 minutes.
47
+ sync_timeout: Maximum time in seconds to wait for a sync process
48
+ to finish running. A sync operation will run for at most this long
49
+ before raising a `TimeoutError`. Defaults to 30 minutes.
50
+ sync_artifacts: [Beta] Whether or not to sync artifacts that are saved to the
51
+ trial directory (accessed via `train.get_context().get_trial_dir()`)
52
+ to the persistent storage configured via `train.RunConfig(storage_path)`.
53
+ The trial or remote worker will try to launch an artifact syncing
54
+ operation every time `train.report` happens, subject to `sync_period`
55
+ and `sync_artifacts_on_checkpoint`.
56
+ Defaults to False -- no artifacts are persisted by default.
57
+ sync_artifacts_on_checkpoint: If True, trial/worker artifacts are
58
+ forcefully synced on every reported checkpoint.
59
+ This only has an effect if `sync_artifacts` is True.
60
+ Defaults to True.
61
+ """
62
+
63
+ sync_period: int = DEFAULT_SYNC_PERIOD
64
+ sync_timeout: int = DEFAULT_SYNC_TIMEOUT
65
+ sync_artifacts: bool = False
66
+ sync_artifacts_on_checkpoint: bool = True
67
+ upload_dir: Optional[str] = _DEPRECATED_VALUE
68
+ syncer: Optional[Union[str, "Syncer"]] = _DEPRECATED_VALUE
69
+ sync_on_checkpoint: bool = _DEPRECATED_VALUE
70
+
71
+ # TODO(justinvyu): [Deprecated] Remove in 2.11.
72
+ def _deprecation_warning(self, attr_name: str, extra_msg: str):
73
+ if getattr(self, attr_name) != _DEPRECATED_VALUE:
74
+ raise DeprecationWarning(
75
+ f"`SyncConfig({attr_name})` is a deprecated configuration. "
76
+ "Please remove it from your `SyncConfig`. "
77
+ f"{extra_msg}"
78
+ )
79
+
80
+ def __post_init__(self):
81
+ for attr_name, extra_msg in [
82
+ (
83
+ "upload_dir",
84
+ "\nPlease specify `ray.train.RunConfig(storage_path)` instead.",
85
+ ),
86
+ (
87
+ "syncer",
88
+ "\nPlease implement custom syncing logic with a custom "
89
+ "`pyarrow.fs.FileSystem` instead, and pass it into "
90
+ "`ray.train.RunConfig(storage_filesystem)`. "
91
+ "See here: https://docs.ray.io/en/latest/train/user-guides/persistent-storage.html#custom-storage", # noqa: E501
92
+ ),
93
+ ("sync_on_checkpoint", ""),
94
+ ]:
95
+ self._deprecation_warning(attr_name, extra_msg)
96
+
97
+ def _repr_html_(self) -> str:
98
+ """Generate an HTML representation of the SyncConfig."""
99
+ return Template("scrollableTable.html.j2").render(
100
+ table=tabulate(
101
+ {
102
+ "Setting": ["Sync period", "Sync timeout"],
103
+ "Value": [self.sync_period, self.sync_timeout],
104
+ },
105
+ tablefmt="html",
106
+ showindex=False,
107
+ headers="keys",
108
+ ),
109
+ max_height="none",
110
+ )
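A minimal sketch of passing this configuration into a run via `RunConfig`; the storage path is hypothetical.

```python
from ray.train import RunConfig, SyncConfig

run_config = RunConfig(
    storage_path="s3://my-bucket/ray_results",  # hypothetical bucket
    sync_config=SyncConfig(
        sync_period=300,       # sync at most every 5 minutes
        sync_timeout=1800,     # give up on a single sync after 30 minutes
        sync_artifacts=True,   # also upload files written to the trial directory
    ),
)
```

With `sync_artifacts_on_checkpoint` left at its default of True, artifacts are additionally force-synced on every reported checkpoint.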
111
+
112
+
113
+ class _BackgroundProcess:
114
+ def __init__(self, fn: Callable):
115
+ self._fn = fn
116
+ self._process = None
117
+ self._result = {}
118
+ self._start_time = float("-inf")
119
+
120
+ @property
121
+ def is_running(self):
122
+ return self._process and self._process.is_alive()
123
+
124
+ @property
125
+ def start_time(self):
126
+ return self._start_time
127
+
128
+ def start(self, *args, **kwargs):
129
+ if self.is_running:
130
+ return False
131
+
132
+ self._result = {}
133
+
134
+ def entrypoint():
135
+ try:
136
+ result = self._fn(*args, **kwargs)
137
+ except Exception as e:
138
+ self._result["exception"] = e
139
+ return
140
+
141
+ self._result["result"] = result
142
+
143
+ self._process = threading.Thread(target=entrypoint)
144
+ self._process.daemon = True
145
+ self._process.start()
146
+ self._start_time = time.time()
147
+
148
+ def wait(self, timeout: Optional[float] = None) -> Any:
149
+ """Waits for the background process to finish running, waiting at most
150
+ until the process has been running for `timeout` seconds, counted from
151
+ the time when the process was started."""
152
+ if not self._process:
153
+ return None
154
+
155
+ time_remaining = None
156
+ if timeout:
157
+ elapsed = time.time() - self.start_time
158
+ time_remaining = max(timeout - elapsed, 0)
159
+
160
+ self._process.join(timeout=time_remaining)
161
+
162
+ if self._process.is_alive():
163
+ self._process = None
164
+ raise TimeoutError(
165
+ f"{getattr(self._fn, '__name__', str(self._fn))} did not finish "
166
+ f"running within the timeout of {timeout} seconds."
167
+ )
168
+
169
+ self._process = None
170
+
171
+ exception = self._result.get("exception")
172
+ if exception:
173
+ raise exception
174
+
175
+ result = self._result.get("result")
176
+
177
+ self._result = {}
178
+ return result
179
+
180
+
181
+ @DeveloperAPI
182
+ class Syncer(abc.ABC):
183
+ """Syncer class for synchronizing data between Ray nodes and remote (cloud) storage.
184
+
185
+ This class handles data transfer for two cases:
186
+
187
+ 1. Synchronizing data such as experiment state snapshots from the driver to
188
+ cloud storage.
189
+ 2. Synchronizing data such as trial checkpoints from remote trainables to
190
+ cloud storage.
191
+
192
+ Synchronizing tasks are usually asynchronous and can be awaited using ``wait()``.
193
+ The base class implements a ``wait_or_retry()`` API that will retry a failed
194
+ sync command.
195
+
196
+ The base class also exposes an API to only kick off syncs every ``sync_period``
197
+ seconds.
198
+
199
+ Args:
200
+ sync_period: The minimum time in seconds between sync operations, as
201
+ used by ``sync_up/down_if_needed``.
202
+ sync_timeout: The maximum time to wait for a sync process to finish before
203
+ issuing a new sync operation. Ex: should be used by ``wait`` if launching
204
+ asynchronous sync tasks.
205
+ """
206
+
207
+ def __init__(
208
+ self,
209
+ sync_period: float = DEFAULT_SYNC_PERIOD,
210
+ sync_timeout: float = DEFAULT_SYNC_TIMEOUT,
211
+ ):
212
+ self.sync_period = sync_period
213
+ self.sync_timeout = sync_timeout
214
+ self.last_sync_up_time = float("-inf")
215
+ self.last_sync_down_time = float("-inf")
216
+
217
+ @abc.abstractmethod
218
+ def sync_up(
219
+ self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
220
+ ) -> bool:
221
+ """Synchronize local directory to remote directory.
222
+
223
+ This function can spawn an asynchronous process that can be awaited in
224
+ ``wait()``.
225
+
226
+ Args:
227
+ local_dir: Local directory to sync from.
228
+ remote_dir: Remote directory to sync up to. This is a URI
229
+ (``protocol://remote/path``).
230
+ exclude: Pattern of files to exclude, e.g.
231
+ ``["*/checkpoint_*"]`` to exclude trial checkpoints.
232
+
233
+ Returns:
234
+ True if sync process has been spawned, False otherwise.
235
+
236
+ """
237
+ raise NotImplementedError
238
+
239
+ @abc.abstractmethod
240
+ def sync_down(
241
+ self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
242
+ ) -> bool:
243
+ """Synchronize remote directory to local directory.
244
+
245
+ This function can spawn an asynchronous process that can be awaited in
246
+ ``wait()``.
247
+
248
+ Args:
249
+ remote_dir: Remote directory to sync down from. This is a URI
250
+ (``protocol://remote/path``).
251
+ local_dir: Local directory to sync to.
252
+ exclude: Pattern of files to exclude, e.g.
253
+ ``["*/checkpoint_*"]`` to exclude trial checkpoints.
254
+
255
+ Returns:
256
+ True if sync process has been spawned, False otherwise.
257
+
258
+ """
259
+ raise NotImplementedError
260
+
261
+ @abc.abstractmethod
262
+ def delete(self, remote_dir: str) -> bool:
263
+ """Delete directory on remote storage.
264
+
265
+ This function can spawn an asynchronous process that can be awaited in
266
+ ``wait()``.
267
+
268
+ Args:
269
+ remote_dir: Remote directory to delete. This is a URI
270
+ (``protocol://remote/path``).
271
+
272
+ Returns:
273
+ True if sync process has been spawned, False otherwise.
274
+
275
+ """
276
+ raise NotImplementedError
277
+
278
+ def retry(self):
279
+ """Retry the last sync up, sync down, or delete command.
280
+
281
+ You should implement this method if you spawn asynchronous syncing
282
+ processes.
283
+ """
284
+ pass
285
+
286
+ def wait(self, timeout: Optional[float] = None):
287
+ """Wait for asynchronous sync command to finish.
288
+
289
+ You should implement this method if you spawn asynchronous syncing
290
+ processes. This method should timeout after the asynchronous command
291
+ has run for `sync_timeout` seconds and raise a `TimeoutError`.
292
+ """
293
+ pass
294
+
295
+ def sync_up_if_needed(
296
+ self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
297
+ ) -> bool:
298
+ """Syncs up if time since last sync up is greater than sync_period.
299
+
300
+ Args:
301
+ local_dir: Local directory to sync from.
302
+ remote_dir: Remote directory to sync up to. This is a URI
303
+ (``protocol://remote/path``).
304
+ exclude: Pattern of files to exclude, e.g.
305
+ ``["*/checkpoint_*"]`` to exclude trial checkpoints.
306
+ """
307
+ now = time.time()
308
+ if now - self.last_sync_up_time >= self.sync_period:
309
+ result = self.sync_up(
310
+ local_dir=local_dir, remote_dir=remote_dir, exclude=exclude
311
+ )
312
+ self.last_sync_up_time = now
313
+ return result
314
+
315
+ def sync_down_if_needed(
316
+ self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
317
+ ):
318
+ """Syncs down if time since last sync down is greater than sync_period.
319
+
320
+ Args:
321
+ remote_dir: Remote directory to sync down from. This is a URI
322
+ (``protocol://remote/path``).
323
+ local_dir: Local directory to sync to.
324
+ exclude: Pattern of files to exclude, e.g.
325
+ ``["*/checkpoint_*"]`` to exclude trial checkpoints.
326
+ """
327
+ now = time.time()
328
+ if now - self.last_sync_down_time >= self.sync_period:
329
+ result = self.sync_down(
330
+ remote_dir=remote_dir, local_dir=local_dir, exclude=exclude
331
+ )
332
+ self.last_sync_down_time = now
333
+ return result
334
+
335
+ def wait_or_retry(self, max_retries: int = 2, backoff_s: int = 5):
336
+ assert max_retries > 0
337
+ last_error_traceback = None
338
+ for i in range(max_retries + 1):
339
+ try:
340
+ self.wait()
341
+ except Exception as e:
342
+ attempts_remaining = max_retries - i
343
+
344
+ # If we're out of retries, then save the full traceback of the last
345
+ # error and show it when raising an exception.
346
+ if attempts_remaining == 0:
347
+ last_error_traceback = traceback.format_exc()
348
+ break
349
+
350
+ logger.error(
351
+ f"The latest sync operation failed with the following error: "
352
+ f"{repr(e)}\n"
353
+ f"Retrying {attempts_remaining} more time(s) after sleeping "
354
+ f"for {backoff_s} seconds..."
355
+ )
356
+ time.sleep(backoff_s)
357
+ self.retry()
358
+ continue
359
+ # Succeeded!
360
+ return
361
+ raise RuntimeError(
362
+ f"Failed sync even after {max_retries} retries. "
363
+ f"The latest sync failed with the following error:\n{last_error_traceback}"
364
+ )
365
+
366
+ def reset(self):
367
+ self.last_sync_up_time = float("-inf")
368
+ self.last_sync_down_time = float("-inf")
369
+
370
+ def close(self):
371
+ pass
372
+
373
+ def _repr_html_(self) -> str:
374
+ return
375
+
376
+
377
+ class _BackgroundSyncer(Syncer):
378
+ """Syncer using a background process for asynchronous file transfer."""
379
+
380
+ def __init__(
381
+ self,
382
+ sync_period: float = DEFAULT_SYNC_PERIOD,
383
+ sync_timeout: float = DEFAULT_SYNC_TIMEOUT,
384
+ ):
385
+ super(_BackgroundSyncer, self).__init__(
386
+ sync_period=sync_period, sync_timeout=sync_timeout
387
+ )
388
+ self._sync_process = None
389
+ self._current_cmd = None
390
+
391
+ def _should_continue_existing_sync(self):
392
+ """Returns whether a previous sync is still running within the timeout."""
393
+ return (
394
+ self._sync_process
395
+ and self._sync_process.is_running
396
+ and time.time() - self._sync_process.start_time < self.sync_timeout
397
+ )
398
+
399
+ def _launch_sync_process(self, sync_command: Tuple[Callable, Dict]):
400
+ """Waits for the previous sync process to finish,
401
+ then launches a new process that runs the given command."""
402
+ if self._sync_process:
403
+ try:
404
+ self.wait()
405
+ except Exception:
406
+ logger.warning(
407
+ f"Last sync command failed with the following error:\n"
408
+ f"{traceback.format_exc()}"
409
+ )
410
+
411
+ self._current_cmd = sync_command
412
+ self.retry()
413
+
414
+ def sync_up(
415
+ self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
416
+ ) -> bool:
417
+ if self._should_continue_existing_sync():
418
+ logger.debug(
419
+ f"Last sync still in progress, "
420
+ f"skipping sync up of {local_dir} to {remote_dir}"
421
+ )
422
+ return False
423
+
424
+ sync_up_cmd = self._sync_up_command(
425
+ local_path=local_dir, uri=remote_dir, exclude=exclude
426
+ )
427
+ self._launch_sync_process(sync_up_cmd)
428
+
429
+ return True
430
+
431
+ def _sync_up_command(
432
+ self, local_path: str, uri: str, exclude: Optional[List] = None
433
+ ) -> Tuple[Callable, Dict]:
434
+ raise NotImplementedError
435
+
436
+ def sync_down(
437
+ self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
438
+ ) -> bool:
439
+ if self._should_continue_existing_sync():
440
+ logger.warning(
441
+ f"Last sync still in progress, "
442
+ f"skipping sync down of {remote_dir} to {local_dir}"
443
+ )
444
+ return False
445
+
446
+ sync_down_cmd = self._sync_down_command(uri=remote_dir, local_path=local_dir)
447
+ self._launch_sync_process(sync_down_cmd)
448
+
449
+ return True
450
+
451
+ def _sync_down_command(self, uri: str, local_path: str) -> Tuple[Callable, Dict]:
452
+ raise NotImplementedError
453
+
454
+ def delete(self, remote_dir: str) -> bool:
455
+ if self._should_continue_existing_sync():
456
+ logger.warning(
457
+ f"Last sync still in progress, skipping deletion of {remote_dir}"
458
+ )
459
+ return False
460
+
461
+ delete_cmd = self._delete_command(uri=remote_dir)
462
+ self._launch_sync_process(delete_cmd)
463
+
464
+ return True
465
+
466
+ def _delete_command(self, uri: str) -> Tuple[Callable, Dict]:
467
+ raise NotImplementedError
468
+
469
+ def wait(self, timeout: Optional[float] = None):
470
+ if self._sync_process:
471
+ try:
472
+ self._sync_process.wait(timeout=timeout or self.sync_timeout)
473
+ except Exception as e:
474
+ raise e
475
+ finally:
476
+ # Regardless of whether the sync process succeeded within the timeout,
477
+ # clear the sync process so a new one can be created.
478
+ self._sync_process = None
479
+
480
+ def retry(self):
481
+ if not self._current_cmd:
482
+ raise RuntimeError("No sync command set, cannot retry.")
483
+ cmd, kwargs = self._current_cmd
484
+ self._sync_process = _BackgroundProcess(cmd)
485
+ self._sync_process.start(**kwargs)
486
+
487
+ def __getstate__(self):
488
+ state = self.__dict__.copy()
489
+ state["_sync_process"] = None
490
+ return state
.venv/lib/python3.11/site-packages/ray/train/_internal/utils.py ADDED
@@ -0,0 +1,239 @@
1
+ import abc
2
+ import functools
3
+ import inspect
4
+ import logging
5
+ import os
6
+ from pathlib import Path
7
+ from typing import (
8
+ Any,
9
+ Callable,
10
+ ContextManager,
11
+ Dict,
12
+ List,
13
+ Optional,
14
+ Tuple,
15
+ TypeVar,
16
+ Union,
17
+ )
18
+
19
+ import ray
20
+ from ray.actor import ActorHandle
21
+ from ray.air._internal.util import (
22
+ StartTraceback,
23
+ StartTracebackWithWorkerRank,
24
+ find_free_port,
25
+ )
26
+ from ray.exceptions import RayActorError
27
+ from ray.types import ObjectRef
28
+
29
+ T = TypeVar("T")
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ def check_for_failure(
35
+ remote_values: List[ObjectRef],
36
+ ) -> Tuple[bool, Optional[Exception]]:
37
+ """Check for actor failure when retrieving the remote values.
38
+
39
+ Args:
40
+ remote_values: List of object references from Ray actor methods.
41
+
42
+ Returns:
43
+ A tuple of (bool, Exception). The bool is
44
+ True if evaluating all object references is successful, False otherwise.
45
+ """
46
+ unfinished = remote_values.copy()
47
+
48
+ while len(unfinished) > 0:
49
+ finished, unfinished = ray.wait(unfinished)
50
+
51
+ # If a failure occurs the ObjectRef will be marked as finished.
52
+ # Calling ray.get will expose the failure as a RayActorError.
53
+ for object_ref in finished:
54
+ # Everything in finished has either failed or completed
55
+ # successfully.
56
+ try:
57
+ ray.get(object_ref)
58
+ except RayActorError as exc:
59
+ failed_actor_rank = remote_values.index(object_ref)
60
+ logger.info(f"Worker {failed_actor_rank} has failed.")
61
+ return False, exc
62
+ except Exception as exc:
63
+ # Other (e.g. training) errors should be directly raised
64
+ failed_worker_rank = remote_values.index(object_ref)
65
+ raise StartTracebackWithWorkerRank(
66
+ worker_rank=failed_worker_rank
67
+ ) from exc
68
+
69
+ return True, None
70
+
71
+
72
+ def get_address_and_port() -> Tuple[str, int]:
73
+ """Returns the IP address and a free port on this node."""
74
+ addr = ray.util.get_node_ip_address()
75
+ port = find_free_port()
76
+
77
+ return addr, port
78
+
79
+
80
+ def construct_path(path: Path, parent_path: Path) -> Path:
81
+ """Constructs a path relative to a parent.
82
+
83
+ Args:
84
+ path: A relative or absolute path.
85
+ parent_path: A relative path or absolute path.
86
+
87
+ Returns: An absolute path.
88
+ """
89
+ if path.expanduser().is_absolute():
90
+ return path.expanduser().resolve()
91
+ else:
92
+ return parent_path.joinpath(path).expanduser().resolve()
93
+
94
+
95
+ def update_env_vars(env_vars: Dict[str, Any]):
96
+ """Updates the environment variables on this worker process.
97
+
98
+ Args:
99
+ env_vars: Environment variables to set.
100
+ """
101
+ sanitized = {k: str(v) for k, v in env_vars.items()}
102
+ os.environ.update(sanitized)
103
+
104
+
105
+ def count_required_parameters(fn: Callable) -> int:
106
+ """Counts the number of required parameters of a function.
107
+
108
+ NOTE: *args counts as 1 required parameter.
109
+
110
+ Examples
111
+ --------
112
+
113
+ >>> def fn(a, b, /, c, *args, d=1, e=2, **kwargs):
114
+ ... pass
115
+ >>> count_required_parameters(fn)
116
+ 4
117
+
118
+ >>> fn = lambda: 1
119
+ >>> count_required_parameters(fn)
120
+ 0
121
+
122
+ >>> def fn(config, a, b=1, c=2):
123
+ ... pass
124
+ >>> from functools import partial
125
+ >>> count_required_parameters(partial(fn, a=0))
126
+ 1
127
+ """
128
+ params = inspect.signature(fn).parameters.values()
129
+
130
+ positional_param_kinds = {
131
+ inspect.Parameter.POSITIONAL_ONLY,
132
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
133
+ inspect.Parameter.VAR_POSITIONAL,
134
+ }
135
+ return len(
136
+ [
137
+ p
138
+ for p in params
139
+ if p.default == inspect.Parameter.empty and p.kind in positional_param_kinds
140
+ ]
141
+ )
142
+
143
+
144
+ def construct_train_func(
145
+ train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
146
+ config: Optional[Dict[str, Any]],
147
+ train_func_context: ContextManager,
148
+ fn_arg_name: Optional[str] = "train_func",
149
+ discard_returns: bool = False,
150
+ ) -> Callable[[], T]:
151
+ """Validates and constructs the training function to execute.
152
+ Args:
153
+ train_func: The training function to execute.
154
+ This can either take in no arguments or a ``config`` dict.
155
+ config (Optional[Dict]): Configurations to pass into
156
+ ``train_func``. If None then an empty Dict will be created.
157
+ train_func_context: Context manager for user's `train_func`, which executes
158
+ backend-specific logic before and after the training function.
159
+ fn_arg_name (Optional[str]): The name of training function to use for error
160
+ messages.
161
+ discard_returns: Whether to discard any returns from train_func or not.
162
+ Returns:
163
+ A valid training function.
164
+ Raises:
165
+ ValueError: if the input ``train_func`` is invalid.
166
+ """
167
+ num_required_params = count_required_parameters(train_func)
168
+
169
+ if discard_returns:
170
+ # Discard any returns from the function so that
171
+ # BackendExecutor doesn't try to deserialize them.
172
+ # Those returns are inaccessible with AIR anyway.
173
+ @functools.wraps(train_func)
174
+ def discard_return_wrapper(*args, **kwargs):
175
+ try:
176
+ train_func(*args, **kwargs)
177
+ except Exception as e:
178
+ raise StartTraceback from e
179
+
180
+ wrapped_train_func = discard_return_wrapper
181
+ else:
182
+ wrapped_train_func = train_func
183
+
184
+ if num_required_params > 1:
185
+ err_msg = (
186
+ f"{fn_arg_name} should take in 0 or 1 required arguments, but it accepts "
187
+ f"{num_required_params} required arguments instead."
188
+ )
189
+ raise ValueError(err_msg)
190
+ elif num_required_params == 1:
191
+ config = {} if config is None else config
192
+
193
+ @functools.wraps(wrapped_train_func)
194
+ def train_fn():
195
+ try:
196
+ with train_func_context():
197
+ return wrapped_train_func(config)
198
+ except Exception as e:
199
+ raise StartTraceback from e
200
+
201
+ else: # num_params == 0
202
+
203
+ @functools.wraps(wrapped_train_func)
204
+ def train_fn():
205
+ try:
206
+ with train_func_context():
207
+ return wrapped_train_func()
208
+ except Exception as e:
209
+ raise StartTraceback from e
210
+
211
+ return train_fn
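A small sketch of the wrapping behavior, using `contextlib.nullcontext` as a stand-in for the backend-specific context manager; the training function and values are illustrative only.

```python
import contextlib

from ray.train._internal.utils import construct_train_func


def train_func(config):
    return config["lr"] * 2


wrapped = construct_train_func(
    train_func,
    config={"lr": 0.1},
    train_func_context=contextlib.nullcontext,  # stand-in context manager factory
)
print(wrapped())  # 0.2 -- the config dict is injected by the wrapper
```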
212
+
213
+
214
+ class Singleton(abc.ABCMeta):
215
+ """Singleton Abstract Base Class
216
+
217
+ https://stackoverflow.com/questions/33364070/implementing
218
+ -singleton-as-metaclass-but-for-abstract-classes
219
+ """
220
+
221
+ _instances = {}
222
+
223
+ def __call__(cls, *args, **kwargs):
224
+ if cls not in cls._instances:
225
+ cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
226
+ return cls._instances[cls]
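A quick sketch of the metaclass in use; the class name is hypothetical.

```python
from ray.train._internal.utils import Singleton


class GlobalRegistry(metaclass=Singleton):
    """Hypothetical registry: construction always returns the same instance."""

    def __init__(self):
        self.items = []


a = GlobalRegistry()
b = GlobalRegistry()
assert a is b  # only one instance is ever created
```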
227
+
228
+
229
+ class ActorWrapper:
230
+ """Wraps an actor to provide same API as using the base class directly."""
231
+
232
+ def __init__(self, actor: ActorHandle):
233
+ self.actor = actor
234
+
235
+ def __getattr__(self, item):
236
+ # The below will fail if trying to access an attribute (not a method) from the
237
+ # actor.
238
+ actor_method = getattr(self.actor, item)
239
+ return lambda *args, **kwargs: ray.get(actor_method.remote(*args, **kwargs))
.venv/lib/python3.11/site-packages/ray/train/_internal/worker_group.py ADDED
@@ -0,0 +1,426 @@
1
+ import logging
2
+ import os
3
+ import socket
4
+ from collections import defaultdict
5
+ from dataclasses import dataclass
6
+ from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
7
+
8
+ import ray
9
+ from ray.actor import ActorHandle
10
+ from ray.air._internal.util import exception_cause, skip_exceptions
11
+ from ray.types import ObjectRef
12
+ from ray.util.placement_group import PlacementGroup
13
+
14
+ T = TypeVar("T")
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class RayTrainWorker:
20
+ """A class to execute arbitrary functions. Does not hold any state."""
21
+
22
+ def __execute(self, func: Callable[..., T], *args, **kwargs) -> T:
23
+ """Executes the input function and returns the output.
24
+
25
+ Args:
26
+ func: The function to execute.
27
+ args, kwargs: The arguments to pass into func.
28
+ """
29
+ try:
30
+ return func(*args, **kwargs)
31
+ except Exception as e:
32
+ skipped = skip_exceptions(e)
33
+ raise skipped from exception_cause(skipped)
34
+
35
+
36
+ @dataclass
37
+ class WorkerMetadata:
38
+ """Metadata for each worker/actor.
39
+
40
+ This information is expected to stay the same throughout the lifetime of
41
+ actor.
42
+
43
+ Args:
44
+ node_id: ID of the node this worker is on.
45
+ node_ip: IP address of the node this worker is on.
46
+ hostname: Hostname that this worker is on.
47
+ resource_ids: Map of accelerator resources
48
+ ("GPU", "neuron_cores", ..) to their IDs.
49
+ pid: Process ID of this worker.
50
+ """
51
+
52
+ node_id: str
53
+ node_ip: str
54
+ hostname: str
55
+ resource_ids: Dict[str, List[str]]
56
+ pid: int
57
+
58
+
59
+ @dataclass
60
+ class Worker:
61
+ """Class representing a Worker."""
62
+
63
+ actor: ActorHandle
64
+ metadata: WorkerMetadata
65
+
66
+
67
+ def create_executable_class(executable_cls: Optional[Type] = None) -> Type:
68
+ """Create the executable class to use as the Ray actors."""
69
+ if not executable_cls:
70
+ return RayTrainWorker
71
+ elif issubclass(executable_cls, RayTrainWorker):
72
+ return executable_cls
73
+ else:
74
+
75
+ class _WrappedExecutable(executable_cls, RayTrainWorker):
76
+ def __init__(self, *args, **kwargs):
77
+ super().__init__(*args, **kwargs)
78
+
79
+ return _WrappedExecutable
80
+
81
+
82
+ def construct_metadata() -> WorkerMetadata:
83
+ """Creates metadata for this worker.
84
+
85
+ This function is expected to be run on the actor.
86
+ """
87
+ node_id = ray.get_runtime_context().get_node_id()
88
+ node_ip = ray.util.get_node_ip_address()
89
+ hostname = socket.gethostname()
90
+ accelerator_ids = ray.get_runtime_context().get_accelerator_ids()
91
+ pid = os.getpid()
92
+
93
+ return WorkerMetadata(
94
+ node_id=node_id,
95
+ node_ip=node_ip,
96
+ hostname=hostname,
97
+ resource_ids=accelerator_ids,
98
+ pid=pid,
99
+ )
100
+
101
+
102
+ class WorkerGroup:
103
+ """Group of Ray Actors that can execute arbitrary functions.
104
+
105
+ ``WorkerGroup`` launches Ray actors according to the given
106
+ specification. It can then execute arbitrary Python functions in each of
107
+ these workers.
108
+
109
+ If not enough resources are available to launch the actors, the Ray
110
+ cluster will automatically scale up if autoscaling is enabled.
111
+
112
+ Args:
113
+ num_workers: The number of workers (Ray actors) to launch.
114
+ Defaults to 1.
115
+ resources_per_worker (Optional[Dict[str, float]]):
116
+ Dictionary specifying the resources that will be
117
+ requested for each worker. Defaults to {"CPU": 1}.
118
+ actor_cls (Optional[Type]): If specified use this class as the
119
+ remote actors.
120
+ actor_cls_args, actor_cls_kwargs: If ``actor_cls`` is provided,
121
+ these args will be used for the worker initialization.
122
+ placement_group (PlacementGroup|str): The placement group that workers
123
+ should be created in. Defaults to "default" which will inherit the
124
+ parent placement group (if child tasks should be captured).
125
+
126
+
127
+ Example:
128
+
129
+ .. code-block:: python
130
+
131
+ worker_group = WorkerGroup(num_workers=2)
132
+ output = worker_group.execute(lambda: 1)
133
+ assert len(output) == 2
134
+ assert all(o == 1 for o in output)
135
+ """
136
+
137
+ def __init__(
138
+ self,
139
+ num_workers: int = 1,
140
+ resources_per_worker: Optional[Dict[str, float]] = None,
141
+ actor_cls: Type = None,
142
+ actor_cls_args: Optional[Tuple] = None,
143
+ actor_cls_kwargs: Optional[Dict] = None,
144
+ placement_group: Union[PlacementGroup, str] = "default",
145
+ ):
146
+ if resources_per_worker is None:
147
+ resources_per_worker = {"CPU": 1}
148
+ else:
149
+ resources_per_worker = resources_per_worker.copy()
150
+
151
+ if num_workers <= 0:
152
+ raise ValueError(
153
+ "The provided `num_workers` must be greater "
154
+ f"than 0. Received num_workers={num_workers} "
155
+ f"instead."
156
+ )
157
+
158
+ if any(v < 0 for v in resources_per_worker.values()):
159
+ raise ValueError(
160
+ "The number of resources per worker must not be negative. "
161
+ f"Received resources_per_worker={resources_per_worker}."
162
+ )
163
+
164
+ if (actor_cls_args or actor_cls_kwargs) and not actor_cls:
165
+ raise ValueError(
166
+ "`actor_cls_args` or `actor_class_kwargs` are "
167
+ "passed in but no `actor_cls` is passed in."
168
+ )
169
+
170
+ self.num_workers = num_workers
171
+ self.num_cpus_per_worker = resources_per_worker.pop("CPU", 0)
172
+ self.num_gpus_per_worker = resources_per_worker.pop("GPU", 0)
173
+ self.memory_per_worker = resources_per_worker.pop("memory", 0)
174
+ self.workers = []
175
+ self._base_cls = create_executable_class(actor_cls)
176
+ assert issubclass(self._base_cls, RayTrainWorker)
177
+
178
+ self._actor_cls_args = actor_cls_args or []
179
+ self._actor_cls_kwargs = actor_cls_kwargs or {}
180
+
181
+ self._placement_group = placement_group
182
+
183
+ # TODO(matt): Validate resources. Fast-fail if it is impossible to
184
+ # handle the request, rather than hang indefinitely.
185
+ self._remote_cls = ray.remote(
186
+ num_cpus=self.num_cpus_per_worker,
187
+ num_gpus=self.num_gpus_per_worker,
188
+ memory=self.memory_per_worker,
189
+ resources=resources_per_worker,
190
+ )(self._base_cls)
191
+ self.start()
192
+
193
+ def start(self):
194
+ """Starts all the workers in this worker group."""
195
+ if self.workers:
196
+ raise RuntimeError(
197
+ "The workers have already been started. "
198
+ "Please call `shutdown` first if you want to "
199
+ "restart them."
200
+ )
201
+
202
+ logger.debug(f"Starting {self.num_workers} workers.")
203
+ self.add_workers(self.num_workers)
204
+ logger.debug(f"{len(self.workers)} workers have successfully started.")
205
+
206
+ def shutdown(self, patience_s: float = 5):
207
+ """Shutdown all the workers in this worker group.
208
+
209
+ Args:
210
+ patience_s: Attempt a graceful shutdown
211
+ of the workers for this many seconds. Fallback to force kill
212
+ if graceful shutdown is not complete after this time. If
213
+ this is less than or equal to 0, immediately force kill all
214
+ workers.
215
+ """
216
+ logger.debug(f"Shutting down {len(self.workers)} workers.")
217
+ if patience_s <= 0:
218
+ for worker in self.workers:
219
+ ray.kill(worker.actor)
220
+ else:
221
+ done_refs = [w.actor.__ray_terminate__.remote() for w in self.workers]
222
+ # Wait for actors to die gracefully.
223
+ done, not_done = ray.wait(done_refs, timeout=patience_s)
224
+ if not_done:
225
+ logger.debug("Graceful termination failed. Falling back to force kill.")
226
+ # If all actors are not able to die gracefully, then kill them.
227
+ for worker in self.workers:
228
+ ray.kill(worker.actor)
229
+
230
+ logger.debug("Shutdown successful.")
231
+ self.workers = []
232
+
233
+ def execute_async(self, func: Callable[..., T], *args, **kwargs) -> List[ObjectRef]:
234
+ """Execute ``func`` on each worker and return the futures.
235
+
236
+ Args:
237
+ func: A function to call on each worker.
238
+ args, kwargs: Passed directly into func.
239
+
240
+ Returns:
241
+ (List[ObjectRef]) A list of ``ObjectRef`` representing the
242
+ output of ``func`` from each worker. The order is the same
243
+ as ``self.workers``.
244
+
245
+ """
246
+ if not self.workers:
247
+ raise RuntimeError(
248
+ "There are no active workers. This worker "
249
+ "group has most likely been shut down. Please"
250
+ "create a new WorkerGroup or restart this one."
251
+ )
252
+
253
+ return [
254
+ w.actor._RayTrainWorker__execute.options(
255
+ name=f"_RayTrainWorker__execute.{func.__name__}"
256
+ ).remote(func, *args, **kwargs)
257
+ for w in self.workers
258
+ ]
259
+
260
+ def execute(self, func: Callable[..., T], *args, **kwargs) -> List[T]:
261
+ """Execute ``func`` on each worker and return the outputs of ``func``.
262
+
263
+ Args:
264
+ func: A function to call on each worker.
265
+ args, kwargs: Passed directly into func.
266
+
267
+ Returns:
268
+ (List[T]) A list containing the output of ``func`` from each
269
+ worker. The order is the same as ``self.workers``.
270
+
271
+ """
272
+ return ray.get(self.execute_async(func, *args, **kwargs))
273
+
274
+ def execute_single_async(
275
+ self, worker_index: int, func: Callable[..., T], *args, **kwargs
276
+ ) -> ObjectRef:
277
+ """Execute ``func`` on worker ``worker_index`` and return futures.
278
+
279
+ Args:
280
+ worker_index: The index to execute func on.
281
+ func: A function to call on the specified worker.
282
+ args, kwargs: Passed directly into func.
283
+
284
+ Returns:
285
+ (ObjectRef) An ObjectRef representing the output of func.
286
+
287
+ """
288
+ if worker_index >= len(self.workers):
289
+ raise ValueError(
290
+ f"The provided worker_index {worker_index} is "
291
+ f"not valid for {self.num_workers} workers."
292
+ )
293
+ return (
294
+ self.workers[worker_index]
295
+ .actor._RayTrainWorker__execute.options(
296
+ name=f"_RayTrainWorker__execute.{func.__name__}"
297
+ )
298
+ .remote(func, *args, **kwargs)
299
+ )
300
+
301
+ def execute_single(
302
+ self, worker_index: int, func: Callable[..., T], *args, **kwargs
303
+ ) -> T:
304
+ """Execute ``func`` on worker with index ``worker_index``.
305
+
306
+ Args:
307
+ worker_index: The index to execute func on.
308
+ func: A function to call on the specified worker.
309
+ args, kwargs: Passed directly into func.
310
+
311
+ Returns:
312
+ (T) The output of func.
313
+
314
+ """
315
+
316
+ return ray.get(self.execute_single_async(worker_index, func, *args, **kwargs))
317
+
318
+ def remove_workers(self, worker_indexes: List[int]):
319
+ """Removes the workers with the specified indexes.
320
+
321
+ The removed workers will go out of scope and their actor processes
322
+ will be terminated.
323
+
324
+ Args:
325
+ worker_indexes (List[int]): The indexes of the workers to remove.
326
+ """
327
+ new_workers = []
328
+ for i in range(len(self.workers)):
329
+ if i not in worker_indexes:
330
+ new_workers.append(self.workers[i])
331
+ self.workers = new_workers
332
+
333
+ def add_workers(self, num_workers: int):
334
+ """Adds ``num_workers`` to this WorkerGroup.
335
+
336
+ Note: Adding workers when the cluster/placement group is at capacity
337
+ may lead to undefined hanging behavior. If you are attempting to
338
+ replace existing workers in the WorkerGroup, remove_workers() should
339
+ be called first.
340
+
341
+ Args:
342
+ num_workers: The number of workers to add.
343
+ """
344
+ new_actors = []
345
+ new_actor_metadata = []
346
+ for _ in range(num_workers):
347
+ actor = self._remote_cls.options(
348
+ placement_group=self._placement_group
349
+ ).remote(*self._actor_cls_args, **self._actor_cls_kwargs)
350
+ new_actors.append(actor)
351
+ new_actor_metadata.append(
352
+ actor._RayTrainWorker__execute.options(
353
+ name="_RayTrainWorker__execute.construct_metadata"
354
+ ).remote(construct_metadata)
355
+ )
356
+
357
+ # Get metadata from all actors.
358
+ metadata = ray.get(new_actor_metadata)
359
+
360
+ for i in range(len(new_actors)):
361
+ self.workers.append(Worker(actor=new_actors[i], metadata=metadata[i]))
362
+
363
+ def sort_workers_by_node_id_and_gpu_id(self, _first_node_id: Optional[str] = None):
364
+ """Reorder the workers by their node id and the lowest GPU id.
365
+
366
+ This is useful for collocating workers on the same node.
367
+
368
+ Example:
369
+ Given workers with the following attributes:
370
+ worker_0: node_id=1, gpu_ids=[1]
371
+ worker_1: node_id=0, gpu_ids=[0]
372
+ worker_2: node_id=1, gpu_ids=[0]
373
+ worker_3: node_id=0, gpu_ids=[1]
374
+
375
+ The function will perform the following steps:
376
+ 1. Group by node ID:
377
+ node_id=0: worker_1, worker_3
378
+ node_id=1: worker_0, worker_2
379
+
380
+ 2. Sort each group by GPU ID:
381
+ node_id=0: worker_1 (gpu_id=0), worker_3 (gpu_id=1)
382
+ node_id=1: worker_2 (gpu_id=0), worker_0 (gpu_id=1)
383
+
384
+ Resulting in the order: [worker_1, worker_3, worker_2, worker_0]
385
+
386
+ Args:
387
+ _first_node_id: The first ID to group by.
388
+ Set this to the node ID of the trainer coordinator to ensure that the
389
+ rank 0 worker is on the same node, allowing additional resources to
390
+ be specified for rank 0 workers via
391
+ `ScalingConfig(trainer_resources=)`.
392
+ """
393
+ node_id_to_workers = defaultdict(list)
394
+
395
+ if _first_node_id is not None:
396
+ node_id_to_workers[_first_node_id] = []
397
+
398
+ for worker in self.workers:
399
+ node_id_to_workers[worker.metadata.node_id].append(worker)
400
+
401
+ # Sort workers on the same node by the lowest GPU id
402
+ # More details: https://github.com/ray-project/ray/issues/40803
403
+ def get_lowest_gpu_id(worker) -> int:
404
+ gpu_ids = worker.metadata.resource_ids.get("GPU", [])
405
+ # If there are no GPU IDs, return 0 as a default
406
+ if not gpu_ids:
407
+ return 0
408
+
409
+ # Attempt to convert GPU IDs to integers and find the minimum ID.
410
+ # Fallback to return the minimum string-based ID
411
+ try:
412
+ return min(int(gpu_id) for gpu_id in gpu_ids)
413
+ except ValueError:
414
+ return min(gpu_ids)
415
+
416
+ for node_id in node_id_to_workers:
417
+ node_id_to_workers[node_id].sort(key=get_lowest_gpu_id)
418
+
419
+ sorted_workers = []
420
+ for workers in node_id_to_workers.values():
421
+ sorted_workers.extend(workers)
422
+
423
+ self.workers = sorted_workers
424
+
425
+ def __len__(self):
426
+ return len(self.workers)
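
The file above is largely self-documenting, but a compact usage sketch may help tie the methods together. The sketch below is illustrative only: `WorkerGroup` lives under `ray.train._internal`, so it is not a stable public API, and the `report_pid` helper is a placeholder of ours, not part of the diff.

```python
# Minimal sketch (not part of the diff): exercising the WorkerGroup API above.
import os

import ray
from ray.train._internal.worker_group import WorkerGroup


def report_pid() -> int:
    # Runs inside each worker actor.
    return os.getpid()


ray.init()

# Two actors, each reserving the default {"CPU": 1}; actors start in __init__ via start().
group = WorkerGroup(num_workers=2)

pids = group.execute(report_pid)                  # one result per worker, ordered like group.workers
assert len(pids) == len(group) == 2

first_pid = group.execute_single(0, report_pid)   # target a single worker by index

group.shutdown(patience_s=5)                      # graceful terminate, force-kill stragglers after 5s
```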
.venv/lib/python3.11/site-packages/ray/train/horovod/__init__.py ADDED
@@ -0,0 +1,22 @@
+ # isort: off
+ try:
+     import horovod  # noqa: F401
+ except ModuleNotFoundError:
+     raise ModuleNotFoundError(
+         "Horovod isn't installed. To install Horovod with PyTorch support, run 'pip "
+         "install 'horovod[pytorch]''. To install Horovod with TensorFlow support, "
+         "run 'pip install 'horovod[tensorflow]''."
+     )
+ # isort: on
+ 
+ from ray.train.horovod.config import HorovodConfig
+ from ray.train.horovod.horovod_trainer import HorovodTrainer
+ from ray.train.v2._internal.constants import is_v2_enabled
+ 
+ if is_v2_enabled():
+     from ray.train.v2.horovod.horovod_trainer import HorovodTrainer  # noqa: F811
+ 
+ __all__ = ["HorovodConfig", "HorovodTrainer"]
+ 
+ 
+ # DO NOT ADD ANYTHING AFTER THIS LINE.
.venv/lib/python3.11/site-packages/ray/train/horovod/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (930 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/train/horovod/__pycache__/config.cpython-311.pyc ADDED
Binary file (9.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/train/horovod/__pycache__/horovod_trainer.cpython-311.pyc ADDED
Binary file (8.93 kB). View file
 
.venv/lib/python3.11/site-packages/ray/train/horovod/config.py ADDED
@@ -0,0 +1,159 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+ from typing import Optional, Set
4
+
5
+ from horovod.ray.runner import Coordinator
6
+ from horovod.ray.utils import detect_nics, nics_to_env_var
7
+ from horovod.runner.common.util import secret, timeout
8
+
9
+ import ray
10
+ from ray.train._internal.utils import update_env_vars
11
+ from ray.train._internal.worker_group import Worker, WorkerGroup
12
+ from ray.train.backend import Backend, BackendConfig
13
+ from ray.util import PublicAPI
14
+
15
+
16
+ @PublicAPI(stability="beta")
17
+ @dataclass
18
+ class HorovodConfig(BackendConfig):
19
+ """Configurations for Horovod setup.
20
+
21
+ See https://github.com/horovod/horovod/blob/master/horovod/runner/common/util/settings.py # noqa: E501
22
+
23
+ Args:
24
+ nics (Optional[Set[str]]): Network interfaces that can be used for
25
+ communication.
26
+ verbose: Horovod logging verbosity.
27
+ key (Optional[str]): Secret used for communication between workers.
28
+ ssh_port (Optional[int]): Port for SSH server running on worker nodes.
29
+ ssh_identity_file (Optional[str]): Path to the identity file to
30
+ ssh into different hosts on the cluster.
31
+ ssh_str (Optional[str]): CAUTION WHEN USING THIS. Private key
32
+ file contents. Writes the private key to ssh_identity_file.
33
+ timeout_s: Timeout parameter for Gloo rendezvous.
34
+ placement_group_timeout_s: Timeout parameter for Ray
35
+ Placement Group creation. Currently unused.
36
+ """
37
+
38
+ nics: Optional[Set[str]] = None
39
+ verbose: int = 1
40
+ key: Optional[str] = None
41
+ ssh_port: Optional[int] = None
42
+ ssh_identity_file: Optional[str] = None
43
+ ssh_str: Optional[str] = None
44
+ timeout_s: int = 300
45
+ placement_group_timeout_s: int = 100
46
+
47
+ @property
48
+ def start_timeout(self):
49
+ return timeout.Timeout(
50
+ self.timeout_s,
51
+ message="Timed out waiting for {activity}. Please "
52
+ "check connectivity between servers. You "
53
+ "may need to increase the --start-timeout "
54
+ "parameter if you have too many servers.",
55
+ )
56
+
57
+ def __post_init__(self):
58
+ if self.ssh_str and not os.path.exists(self.ssh_identity_file):
59
+ with open(self.ssh_identity_file, "w") as f:
60
+ os.chmod(self.ssh_identity_file, 0o600)
61
+ f.write(self.ssh_str)
62
+
63
+ if self.key is None:
64
+ self.key = secret.make_secret_key()
65
+
66
+ @property
67
+ def backend_cls(self):
68
+ return _HorovodBackend
69
+
70
+
71
+ class _HorovodBackend(Backend):
72
+ share_cuda_visible_devices: bool = True
73
+
74
+ def on_start(self, worker_group: WorkerGroup, backend_config: HorovodConfig):
75
+ # TODO(matt): Implement placement group strategies in BackendExecutor.
76
+
77
+ # Initialize workers with Horovod environment variables
78
+ setup_futures = []
79
+ for rank in range(len(worker_group)):
80
+ worker_node_id = worker_group.workers[rank].metadata.node_id
81
+ setup_futures.append(
82
+ worker_group.execute_single_async(
83
+ rank,
84
+ _init_env_vars,
85
+ rank,
86
+ len(worker_group),
87
+ worker_node_id,
88
+ )
89
+ )
90
+ ray.get(setup_futures)
91
+
92
+ # Use Horovod Ray Coordinator
93
+ # backend_config as settings
94
+ self.coordinator = Coordinator(backend_config)
95
+
96
+ # Get all the hostnames of all workers
97
+ node_ids = [w.metadata.node_id for w in worker_group.workers]
98
+ hostnames = [w.metadata.hostname for w in worker_group.workers]
99
+ # Register each hostname to the coordinator. assumes the hostname
100
+ # ordering is the same.
101
+ for rank, (hostname, node_id) in enumerate(zip(hostnames, node_ids)):
102
+ self.coordinator.register(hostname, node_id, rank)
103
+ all_info = self.coordinator.finalize_registration()
104
+
105
+ setup_futures = []
106
+ for rank, local_cross_env_var in all_info.items():
107
+ setup_futures.append(
108
+ worker_group.execute_single_async(
109
+ rank, update_env_vars, local_cross_env_var
110
+ )
111
+ )
112
+ ray.get(setup_futures)
113
+
114
+ coordinator_envs = self.coordinator.establish_rendezvous()
115
+
116
+ # Get one worker from each host/node.
117
+ node_worker_indexes = [node_ids.index(node_id) for node_id in set(node_ids)]
118
+ node_workers = [
119
+ _HorovodWorkerWrapper(worker_group.workers[worker_index])
120
+ for worker_index in node_worker_indexes
121
+ ]
122
+ assert len(node_workers) == len(self.coordinator.hostnames)
123
+
124
+ nics = detect_nics(
125
+ backend_config,
126
+ all_host_names=list(self.coordinator.hostnames),
127
+ node_workers=node_workers,
128
+ )
129
+ coordinator_envs.update(nics_to_env_var(nics))
130
+
131
+ worker_group.execute(update_env_vars, coordinator_envs)
132
+
133
+
134
+ def _init_env_vars(world_rank: int, world_size: int, node_id: str):
135
+ """Initialize Horovod environment variables."""
136
+ os.environ["HOROVOD_HOSTNAME"] = node_id
137
+ os.environ["HOROVOD_RANK"] = str(world_rank)
138
+ os.environ["HOROVOD_SIZE"] = str(world_size)
139
+
140
+
141
+ # TODO(tgaddair): temporary workaround for Horovod's worker discovery logic,
142
+ # which requires passing in an extra parameter as part of the RayExecutor
143
+ # API. This will be removed in the future as we migrate more of the
144
+ # RayExecutor utils into Ray Train.
145
+ # See: https://github.com/horovod/horovod/blob/v0.23.0/horovod/ray/driver_service.py#L9 # noqa: E501
146
+ @dataclass
147
+ class _HorovodWorkerWrapper:
148
+ w: Worker
149
+
150
+ @property
151
+ def execute(self):
152
+ w = self.w
153
+
154
+ class ExecuteHandle:
155
+ def remote(self, func, *args, **kwargs):
156
+ _ = None
157
+ return w.actor._RayTrainWorker__execute.remote(func, _, *args, **kwargs)
158
+
159
+ return ExecuteHandle()
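
As a quick orientation for the config above, here is a hedged sketch of constructing `HorovodConfig` and of the hooks the backend machinery relies on. The `nics` and `timeout_s` values are illustrative, and Horovod must be installed for the import to succeed.

```python
# Illustrative sketch (not part of the diff); values are examples only.
from ray.train.horovod import HorovodConfig

config = HorovodConfig(
    nics={"eth0"},     # limit rendezvous/communication traffic to one interface (example value)
    timeout_s=600,     # Gloo rendezvous timeout, surfaced via start_timeout
)

# __post_init__ fills in a generated secret key when none is provided.
assert config.key is not None

# The backend executor uses backend_cls to instantiate the matching backend.
print(config.backend_cls.__name__)   # _HorovodBackend

# start_timeout wraps timeout_s in Horovod's Timeout helper.
print(config.start_timeout)
```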
.venv/lib/python3.11/site-packages/ray/train/horovod/horovod_trainer.py ADDED
@@ -0,0 +1,202 @@
1
+ from typing import Any, Callable, Dict, Optional, Union
2
+
3
+ from ray.air.config import RunConfig, ScalingConfig
4
+ from ray.train import Checkpoint, DataConfig
5
+ from ray.train.data_parallel_trainer import DataParallelTrainer
6
+ from ray.train.horovod.config import HorovodConfig
7
+ from ray.train.trainer import GenDataset
8
+ from ray.util.annotations import PublicAPI
9
+
10
+
11
+ @PublicAPI(stability="beta")
12
+ class HorovodTrainer(DataParallelTrainer):
13
+ """A Trainer for data parallel Horovod training.
14
+
15
+ This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
16
+ Actors. These actors already have the necessary Horovod setup already
17
+ configured for distributed Horovod training.
18
+
19
+ The ``train_loop_per_worker`` function is expected to take in either 0 or 1
20
+ arguments:
21
+
22
+ .. testcode::
23
+
24
+ def train_loop_per_worker():
25
+ ...
26
+
27
+ .. testcode::
28
+
29
+ def train_loop_per_worker(config: Dict):
30
+ ...
31
+
32
+ If ``train_loop_per_worker`` accepts an argument, then
33
+ ``train_loop_config`` will be passed in as the argument. This is useful if you
34
+ want to tune the values in ``train_loop_config`` as hyperparameters.
35
+
36
+ If the ``datasets`` dict contains a training dataset (denoted by
37
+ the "train" key), then it will be split into multiple dataset
38
+ shards that can then be accessed by ``ray.train.get_dataset_shard("train")`` inside
39
+ ``train_loop_per_worker``. All the other datasets will not be split and
40
+ ``ray.train.get_dataset_shard(...)`` will return the the entire Dataset.
41
+
42
+ Inside the ``train_loop_per_worker`` function, you can use any of the
43
+ :ref:`Ray Train loop methods <train-loop-api>`.
44
+
45
+ .. testcode::
46
+
47
+ from ray import train
48
+
49
+ def train_loop_per_worker():
50
+ # Report intermediate results for callbacks or logging and
51
+ # checkpoint data.
52
+ train.report(...)
53
+
54
+ # Returns dict of last saved checkpoint.
55
+ train.get_checkpoint()
56
+
57
+ # Returns the Dataset shard for the given key.
58
+ train.get_dataset_shard("my_dataset")
59
+
60
+ # Returns the total number of workers executing training.
61
+ train.get_context().get_world_size()
62
+
63
+ # Returns the rank of this worker.
64
+ train.get_context().get_world_rank()
65
+
66
+ # Returns the rank of the worker on the current node.
67
+ train.get_context().get_local_rank()
68
+
69
+ Any returns from the ``train_loop_per_worker`` will be discarded and not
70
+ used or persisted anywhere.
71
+
72
+ You could use ``TensorflowPredictor`` or ``TorchPredictor`` in conjunction with
73
+ HorovodTrainer. You must save the model under the "model" kwarg in the
74
+ ``Checkpoint`` passed to ``train.report()``, so that it can be used by
75
+ corresponding predictors.
76
+
77
+ Example:
78
+
79
+
80
+ .. testcode::
81
+ :skipif: True
82
+
83
+ import os
84
+ import tempfile
85
+
86
+ import ray
87
+ import horovod.torch as hvd
88
+ import torch
89
+ import torch.nn as nn
90
+
91
+ from ray import train
92
+ import ray.train.torch # Need this to use `train.torch.get_device()`
93
+ from ray.train import Checkpoint, ScalingConfig
94
+ from ray.train.horovod import HorovodTrainer
95
+
96
+ # If using GPUs, set this to True.
97
+ use_gpu = False
98
+
99
+ input_size = 1
100
+ layer_size = 15
101
+ output_size = 1
102
+ num_epochs = 3
103
+
104
+ class NeuralNetwork(nn.Module):
105
+ def __init__(self):
106
+ super(NeuralNetwork, self).__init__()
107
+ self.layer1 = nn.Linear(input_size, layer_size)
108
+ self.relu = nn.ReLU()
109
+ self.layer2 = nn.Linear(layer_size, output_size)
110
+ def forward(self, input):
111
+ return self.layer2(self.relu(self.layer1(input)))
112
+
113
+ def train_loop_per_worker():
114
+ hvd.init()
115
+ dataset_shard = train.get_dataset_shard("train")
116
+ model = NeuralNetwork()
117
+ device = train.torch.get_device()
118
+ model.to(device)
119
+ loss_fn = nn.MSELoss()
120
+ lr_scaler = 1
121
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.1 * lr_scaler)
122
+ # Horovod: wrap optimizer with DistributedOptimizer.
123
+ optimizer = hvd.DistributedOptimizer(
124
+ optimizer,
125
+ named_parameters=model.named_parameters(),
126
+ op=hvd.Average,
127
+ )
128
+ for epoch in range(num_epochs):
129
+ model.train()
130
+ for batch in dataset_shard.iter_torch_batches(
131
+ batch_size=32, dtypes=torch.float
132
+ ):
133
+ inputs, labels = torch.unsqueeze(batch["x"], 1), batch["y"]
134
+ outputs = model(inputs)
135
+ loss = loss_fn(outputs, labels)
136
+ optimizer.zero_grad()
137
+ loss.backward()
138
+ optimizer.step()
139
+ print(f"epoch: {epoch}, loss: {loss.item()}")
140
+
141
+ # Save a model checkpoint at the end of each epoch
142
+ with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
143
+ ckpt_path = os.path.join(temp_checkpoint_dir, "model.pt")
144
+ torch.save(model.state_dict(), ckpt_path)
145
+ train.report(
146
+ {"loss": loss.item(), "epoch": epoch},
147
+ checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
148
+ )
149
+
150
+ train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
151
+ scaling_config = ScalingConfig(num_workers=3, use_gpu=use_gpu)
152
+ trainer = HorovodTrainer(
153
+ train_loop_per_worker=train_loop_per_worker,
154
+ scaling_config=scaling_config,
155
+ datasets={"train": train_dataset},
156
+ )
157
+ result = trainer.fit()
158
+
159
+ Args:
160
+ train_loop_per_worker: The training function to execute.
161
+ This can either take in no arguments or a ``config`` dict.
162
+ train_loop_config: Configurations to pass into
163
+ ``train_loop_per_worker`` if it accepts an argument.
164
+ horovod_config: Configuration for setting up the Horovod backend.
165
+ If set to None, use the default configuration. This replaces the
166
+ ``backend_config`` arg of ``DataParallelTrainer``.
167
+ scaling_config: Configuration for how to scale data parallel training.
168
+ dataset_config: Configuration for dataset ingest.
169
+ run_config: Configuration for the execution of the training run.
170
+ datasets: Any Datasets to use for training. Use
171
+ the key "train" to denote which dataset is the training
172
+ dataset.
173
+ resume_from_checkpoint: A checkpoint to resume training from.
174
+ metadata: Dict that should be made available via
175
+ `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
176
+ for checkpoints saved from this Trainer. Must be JSON-serializable.
177
+ """
178
+
179
+ def __init__(
180
+ self,
181
+ train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
182
+ *,
183
+ train_loop_config: Optional[Dict] = None,
184
+ horovod_config: Optional[HorovodConfig] = None,
185
+ scaling_config: Optional[ScalingConfig] = None,
186
+ dataset_config: Optional[DataConfig] = None,
187
+ run_config: Optional[RunConfig] = None,
188
+ datasets: Optional[Dict[str, GenDataset]] = None,
189
+ metadata: Optional[Dict[str, Any]] = None,
190
+ resume_from_checkpoint: Optional[Checkpoint] = None,
191
+ ):
192
+ super().__init__(
193
+ train_loop_per_worker=train_loop_per_worker,
194
+ train_loop_config=train_loop_config,
195
+ backend_config=horovod_config or HorovodConfig(),
196
+ scaling_config=scaling_config,
197
+ dataset_config=dataset_config,
198
+ run_config=run_config,
199
+ datasets=datasets,
200
+ resume_from_checkpoint=resume_from_checkpoint,
201
+ metadata=metadata,
202
+ )
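
The class docstring above already contains a full training-loop example; the shorter sketch below only highlights the constructor wiring, i.e. that `horovod_config` is forwarded to `DataParallelTrainer` as its `backend_config`. The loop body is deliberately elided.

```python
# Constructor-wiring sketch (not part of the diff); the loop body is a stub.
from ray.train import ScalingConfig
from ray.train.horovod import HorovodConfig, HorovodTrainer


def train_loop_per_worker():
    ...  # hvd.init(), model/optimizer setup, train.report(...) as in the docstring example


trainer = HorovodTrainer(
    train_loop_per_worker=train_loop_per_worker,
    horovod_config=HorovodConfig(timeout_s=600),   # becomes backend_config of DataParallelTrainer
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)
result = trainer.fit()
```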
.venv/lib/python3.11/site-packages/ray/train/lightning/__init__.py ADDED
@@ -0,0 +1,39 @@
+ # isort: off
+ try:
+     import lightning  # noqa: F401
+ except ModuleNotFoundError:
+     try:
+         import pytorch_lightning  # noqa: F401
+     except ModuleNotFoundError:
+         raise ModuleNotFoundError(
+             "PyTorch Lightning isn't installed. To install PyTorch Lightning, "
+             "please run 'pip install lightning'"
+         )
+ # isort: on
+ 
+ from ray.train.lightning._lightning_utils import (
+     RayDDPStrategy,
+     RayDeepSpeedStrategy,
+     RayFSDPStrategy,
+     RayLightningEnvironment,
+     RayTrainReportCallback,
+     prepare_trainer,
+ )
+ from ray.train.v2._internal.constants import is_v2_enabled
+ 
+ if is_v2_enabled():
+     from ray.train.v2.lightning.lightning_utils import (  # noqa: F811
+         RayTrainReportCallback,
+     )
+ 
+ __all__ = [
+     "prepare_trainer",
+     "RayDDPStrategy",
+     "RayFSDPStrategy",
+     "RayDeepSpeedStrategy",
+     "RayLightningEnvironment",
+     "RayTrainReportCallback",
+ ]
+ 
+ 
+ # DO NOT ADD ANYTHING AFTER THIS LINE.
.venv/lib/python3.11/site-packages/ray/train/lightning/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.16 kB). View file
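
For context on the Lightning exports added above, the following is a hedged sketch of the usual integration pattern: build a `pl.Trainer` with the Ray strategy, environment, and report callback, wrap it with `prepare_trainer`, and drive it from a `TorchTrainer`. `MyLightningModule` and `train_loader` are placeholders, not part of the diff.

```python
# Illustrative sketch only; MyLightningModule and train_loader are hypothetical.
import pytorch_lightning as pl

from ray.train import ScalingConfig
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)
from ray.train.torch import TorchTrainer


def train_func():
    model = MyLightningModule()               # hypothetical LightningModule
    trainer = pl.Trainer(
        max_epochs=1,
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),            # DDP strategy aware of Ray Train workers
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],  # reports metrics/checkpoints to Ray Train
        enable_checkpointing=False,
    )
    trainer = prepare_trainer(trainer)        # validates/patches the Trainer for Ray Train
    trainer.fit(model, train_dataloaders=train_loader)  # hypothetical dataloader


ray_trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)
result = ray_trainer.fit()
```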