diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3e0140cfcd835880b6beba0d27fca6266ab739c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__init__.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8345c12d83c51dd8e91eb9d08ba1162278c600d6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/exceptions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/exceptions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0ccf483494aa40a798bbcfbd04f33f463b024de Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/exceptions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/util.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..941ec39c856449da8f0783a440ef68d2094d16cb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__init__.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b46f144717580c49f3632578f2d1f4d1ef3080d6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__init__.py @@ -0,0 +1,14 @@ +from .accelerators import AcceleratorSetupCallback +from .backend_setup import BackendSetupCallback +from .datasets import DatasetsSetupCallback +from .working_dir_setup import WorkingDirectorySetupCallback + +__all__ = [ + "AcceleratorSetupCallback", + "BackendSetupCallback", + "DatasetsSetupCallback", + "WorkingDirectorySetupCallback", +] + + +# DO NOT ADD ANYTHING AFTER THIS LINE. diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/accelerators.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/accelerators.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3769d06f050fa14455df8bf79c0ffea9bf6cedf9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/accelerators.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/backend_setup.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/backend_setup.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0abe3891611d4a2348f9eb68ced0a5374e5c9146 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/backend_setup.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/datasets.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/datasets.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b33be2cd82afbe3052c4ad586ec1f2e093893955 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/datasets.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/metrics.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68c58324526855e78f8c766b9826158cde4281bb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/metrics.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/user_callback.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/user_callback.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa43ae703c14eb464681af8df0a191fd3891b052 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/user_callback.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/working_dir_setup.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/working_dir_setup.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..117a0eaafa22f9e99e884a2a865ec6bcb42e3761 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/working_dir_setup.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/accelerators.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/accelerators.py new file mode 100644 index 0000000000000000000000000000000000000000..4e0eee326aadc06102fc7d79d221a9e8a51e4ee4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/accelerators.py @@ -0,0 +1,151 @@ +import logging +import os +from collections import defaultdict +from typing import List + +import ray._private.ray_constants as ray_constants +from ray._private.ray_constants import env_bool +from ray.train import BackendConfig +from ray.train.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV +from ray.train.v2._internal.execution.callback import WorkerGroupCallback +from ray.train.v2._internal.execution.worker_group import ActorMetadata, WorkerGroup +from ray.train.v2._internal.util import ray_get_safe +from ray.train.v2.api.config import ScalingConfig + +logger = logging.getLogger(__name__) + + +class AcceleratorSetupCallback(WorkerGroupCallback): + """Perform accelerator setup for workers. + + For example, this callback can be used to share CUDA_VISIBLE_DEVICES + among workers on the same node. + """ + + def __init__(self, backend_config: BackendConfig, scaling_config: ScalingConfig): + self._backend = backend_config.backend_cls() + self._scaling_config = scaling_config + + def after_worker_group_start(self, worker_group: WorkerGroup): + self._maybe_share_cuda_visible_devices(worker_group) + # TODO: Add support for sharing other accelerator resources. + + def _maybe_share_cuda_visible_devices(self, worker_group: WorkerGroup): + share_cuda_visible_devices_enabled = env_bool( + ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, + self._backend.share_cuda_visible_devices, + ) + + if ( + self._scaling_config._resources_per_worker_not_none.get("GPU", 0) > 0 + and share_cuda_visible_devices_enabled + ): + _share_cuda_visible_devices(worker_group) + + +def _share_cuda_visible_devices(worker_group: WorkerGroup): + """Sets CUDA_VISIBLE_DEVICES on all workers. + For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs + visible to all workers on that worker's node. + This allows GPU workers on the same node to communicate with one + another. + + Example: + Setup: + - Node1: + - Worker1: {0, 1} + - Worker2: {2, 3} + - Node2: + - Worker3: {0, 1} + CUDA_VISIBLE_DEVICES: + - Worker1: "0,1,2,3" + - Worker2: "0,1,2,3" + - Worker2: "0,1" + """ + _share_accelerator_ids( + worker_group, ray_constants.GPU, ray_constants.CUDA_VISIBLE_DEVICES_ENV_VAR + ) + + +def _share_accelerator_ids( + worker_group: WorkerGroup, accelerator_name: str, env_var: str +): + """Sets the given env_var on all workers. + For each worker, the cores/devices are visible to all the + workers on that worker's node. This allows workers on the + same node to communicate with one another. + + Example: + Setup: + - Node1: + - Worker1: {0, 1} + - Worker2: {2, 3} + - Node2: + - Worker3: {0, 1} + NEURON_RT_VISIBLE_CORES/TPU_VISIBLE_CHIPS/...: + - Worker1: "0,1,2,3" + - Worker2: "0,1,2,3" + - Worker2: "0,1" + + Args: + accelerator_name: The name of the accelerator. + env_var: The name of the environment variable to set. + """ + if not worker_group.has_started(): + raise RuntimeError( + "WorkerGroup must be started before sharing accelerator IDs." + ) + + worker_metadatas = [worker.metadata for worker in worker_group.get_workers()] + visible_accelerator_ids_per_worker = _get_visible_accelerator_ids_per_worker( + worker_metadatas=worker_metadatas, accelerator_name=accelerator_name + ) + + def set_accelerator_ids(accelerator_ids): + os.environ[env_var] = accelerator_ids + + futures = [] + for rank, visible_accelerator_ids in enumerate(visible_accelerator_ids_per_worker): + futures.append( + worker_group.execute_single_async( + rank, set_accelerator_ids, accelerator_ids=visible_accelerator_ids + ) + ) + ray_get_safe(futures) + + +def _get_visible_accelerator_ids_per_worker( + worker_metadatas: List[ActorMetadata], accelerator_name: str +) -> List[str]: + """Returns a list of comma-separated accelerator IDs visible to each worker. + + All workers on a node should have the same set of visible accelerators, + which is the union of accelerator ids of the workers. + + Returns: + visible_accelerator_ids_per_worker: A list of comma-separated accelerator ID + strings. This list is the same length as the number of workers. + + """ + for metadata in worker_metadatas: + if accelerator_name not in metadata.accelerator_ids: + raise ValueError( + f"Accelerator '{accelerator_name}' is not available on all workers. " + f"Got these available accelerators instead: {metadata.accelerator_ids}" + ) + + node_id_to_accelerator_ids = defaultdict(set) + + for metadata in worker_metadatas: + node_id_to_accelerator_ids[metadata.node_id].update( + metadata.accelerator_ids[accelerator_name] + ) + + visible_accelerator_ids_per_worker = [] + for worker_id in range(len(worker_metadatas)): + node_id = worker_metadatas[worker_id].node_id + accelerator_ids = sorted(node_id_to_accelerator_ids[node_id]) + all_resource_ids = ",".join([str(id) for id in accelerator_ids]) + visible_accelerator_ids_per_worker.append(all_resource_ids) + + return visible_accelerator_ids_per_worker diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/backend_setup.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/backend_setup.py new file mode 100644 index 0000000000000000000000000000000000000000..8b592980f7e5c201f7ff541a0f592b7f1f54428f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/backend_setup.py @@ -0,0 +1,27 @@ +import logging + +from ray.exceptions import RayActorError +from ray.train.backend import BackendConfig +from ray.train.v2._internal.execution.callback import WorkerGroupCallback +from ray.train.v2._internal.execution.worker_group import WorkerGroup + +logger = logging.getLogger(__name__) + + +class BackendSetupCallback(WorkerGroupCallback): + def __init__(self, backend_config: BackendConfig): + self._backend_config = backend_config + self._backend = backend_config.backend_cls() + + def after_worker_group_start(self, worker_group: WorkerGroup): + self._backend.on_start(worker_group, self._backend_config) + self._backend.on_training_start(worker_group, self._backend_config) + + def before_worker_group_shutdown(self, worker_group: WorkerGroup): + try: + self._backend.on_shutdown(worker_group, self._backend_config) + except RayActorError: + logger.warning( + "Graceful shutdown of backend failed. This is " + "expected if one of the workers has crashed." + ) diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/datasets.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..80536b3248c2c1f9378f817dae79787d63047cac --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/datasets.py @@ -0,0 +1,76 @@ +import copy +from typing import Any, Callable, Dict, List, Union + +import ray.train +from ray.data import Dataset +from ray.data.context import DataContext +from ray.train.v2._internal.execution.callback import WorkerGroupCallback +from ray.train.v2._internal.execution.worker_group.worker_group import WorkerGroup + +# A type representing either a ray.data.Dataset or a function that returns a +# ray.data.Dataset and accepts no arguments. +GenDataset = Union[Dataset, Callable[[], Dataset]] + + +class DatasetsSetupCallback(WorkerGroupCallback): + """The callback to setup Ray Datasets for the worker group.""" + + def __init__( + self, + datasets: Dict[str, GenDataset], + data_config: ray.train.DataConfig, + scaling_config: ray.train.ScalingConfig, + ): + self._datasets = datasets + self._data_config = data_config + self._scaling_config = scaling_config + + # Capture the current DataContext to propagate it to + # the Train workers later. + # The propagation works in the following way: + # 1. This callback is created when user create the Trainer. + # 2. Then this callback will be passed to the Controller actor. + # 3. Lastly, when the worker group is initialized, the Controller + # will call the `after_worker_group_start` callback to propagate + # the DataContext to Train workers. + self._data_context = copy.deepcopy(DataContext.get_current()) + + def get_train_total_resources( + self, scaling_config: ray.train.ScalingConfig + ) -> Dict[str, float]: + """Return the resources reserved for training, so that Data can exclude + these resources logically from its available pool.""" + return scaling_config.total_resources + + def before_init_train_context( + self, worker_group: "WorkerGroup" + ) -> Dict[str, List[Any]]: + # Configure dataset shards + datasets = {k: v() if callable(v) else v for k, v in self._datasets.items()} + node_ids = [worker.metadata.node_id for worker in worker_group.get_workers()] + + # Notify the DataConfig about the total resources reserved for training. + total_train_resources = self.get_train_total_resources(self._scaling_config) + self._data_config.set_train_total_resources( + total_train_resources.get("CPU", 0), total_train_resources.get("GPU", 0) + ) + + dataset_shards = self._data_config.configure( + datasets, + world_size=len(worker_group), + worker_handles=None, + worker_node_ids=node_ids, + ) + assert len(dataset_shards) == len(worker_group) + + return {"dataset_shards": dataset_shards} + + def after_worker_group_start(self, worker_group: "WorkerGroup"): + # Propagate DataContext + def _propagate_data_context(ctx: DataContext): + DataContext._set_current(ctx) + + worker_group.execute( + _propagate_data_context, + self._data_context, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/metrics.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd5a7c48336ee4e13c81163b896ecaee9123e76 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/metrics.py @@ -0,0 +1,250 @@ +import threading +import time +from contextlib import contextmanager +from dataclasses import asdict, dataclass, field, fields +from typing import Dict, Optional + +from ray.train.v2._internal.execution.callback import ( + ControllerCallback, + TrainContextCallback, + WorkerCallback, + WorkerGroupCallback, +) +from ray.train.v2._internal.execution.context import TrainRunContext, get_train_context +from ray.train.v2._internal.util import time_monotonic +from ray.util.metrics import Gauge + +# Prometheus Tag keys for the worker and controller metrics. +RUN_NAME_TAG_KEY = "ray_train_run_name" +WORKER_WORLD_RANK_TAG_KEY = "ray_train_worker_world_rank" + + +@dataclass +class ControllerMetrics: + """A list of Train controller metrics. + + Metric metadata attributes: + - description (required): A human-readable description of the metric, also used as + the chart description on the Ray Train dashboard. + """ + + train_worker_group_start_total_time_s: float = field( + default=0.0, + metadata={ + "description": ( + "Cumulative time in seconds to start worker groups in the Train job." + ), + }, + ) + + train_worker_group_shutdown_total_time_s: float = field( + default=0.0, + metadata={ + "description": ( + "Cumulative time in seconds to shutdown worker groups in the Train job." + ), + }, + ) + + +@dataclass +class WorkerMetrics: + """A list of Train worker metrics. + + Metric metadata attributes: + - description (required): A human-readable description of the metric, also used as + the chart description on the Ray Train dashboard. + """ + + train_report_total_blocked_time_s: float = field( + default=0.0, + metadata={ + "description": ( + "Cumulative time in seconds to report a checkpoint to the storage." + ), + }, + ) + + +class ControllerMetricsCallback(ControllerCallback, WorkerGroupCallback): + # Interval for pushing metrics to Prometheus. + LOCAL_METRICS_PUSH_INTERVAL_S: float = 5.0 + CONTROLLER_TAG_KEYS = (RUN_NAME_TAG_KEY,) + + def __init__(self, train_run_context: TrainRunContext): + """ + This callback is initialized on the driver process and then passed to the + controller. This callback collects metrics from the controller actor as well + as the metrics related to the worker groups. + """ + self._run_name = train_run_context.get_run_config().name + self._thread: Optional[threading.Thread] = None + self._thread_stop_event: Optional[threading.Event] = None + self._metrics: Optional[ControllerMetrics] = None + self._metrics_lock: Optional[threading.Lock] = None + self._controller_tag: Dict[str, str] = {} + self._metrics_gauges: Dict[str, Gauge] = {} + + def _create_prometheus_controller_metrics(self) -> Dict[str, Gauge]: + """Create Prometheus worker metrics for the ControllerMetrics dataclass.""" + metrics = {} + for _field in fields(ControllerMetrics): + metric_description = _field.metadata.get("description") + metrics[_field.name] = Gauge( + _field.name, + description=metric_description, + tag_keys=self.CONTROLLER_TAG_KEYS, + ) + return metrics + + def after_controller_start(self): + """ + Creating a thread to periodically push local metrics to the gauges + after the train controller starts. + """ + self._controller_tag = { + RUN_NAME_TAG_KEY: self._run_name, + } + self._thread_stop_event = threading.Event() + self._metrics_lock = threading.Lock() + self._metrics = ControllerMetrics() + self._metrics_gauges = self._create_prometheus_controller_metrics() + + def push_local_metrics(): + while not self._thread_stop_event.is_set(): + with self._metrics_lock: + metrics_dict = asdict(self._metrics) + for metric_name, metric_value in metrics_dict.items(): + self._metrics_gauges[metric_name].set( + metric_value, self._controller_tag + ) + time.sleep(ControllerMetricsCallback.LOCAL_METRICS_PUSH_INTERVAL_S) + + assert not self._thread + self._thread = threading.Thread(target=push_local_metrics, daemon=True) + self._thread.start() + + def before_controller_shutdown(self): + """ + Stop the thread that pushes local metrics to the gauges before the + controller shuts down. + """ + # Stop the thread that pushes local metrics to the metrics gauges. + assert not self._thread_stop_event.is_set() + self._thread_stop_event.set() + # Reset the metrics to their default values. + for _field in fields(self._metrics): + self._metrics_gauges[_field.name].set(_field.default, self._controller_tag) + + @contextmanager + def on_worker_group_start(self): + """ + Context manager to measure the time taken to start a worker group. + """ + start_time_s = time_monotonic() + yield + elapsed_time_s = time_monotonic() - start_time_s + with self._metrics_lock: + self._metrics.train_worker_group_start_total_time_s += elapsed_time_s + + @contextmanager + def on_worker_group_shutdown(self): + """ + Context manager to measure the time taken to start a worker group. + """ + start_time_s = time_monotonic() + yield + elapsed_time_s = time_monotonic() - start_time_s + with self._metrics_lock: + self._metrics.train_worker_group_shutdown_total_time_s += elapsed_time_s + + +class WorkerMetricsCallback(WorkerCallback, TrainContextCallback): + # Interval for pushing metrics to Prometheus. + LOCAL_METRICS_PUSH_INTERVAL_S: float = 5.0 + WORKER_TAG_KEYS = (RUN_NAME_TAG_KEY, WORKER_WORLD_RANK_TAG_KEY) + + def __init__(self, train_run_context: TrainRunContext): + """ + This callback is initialized on the driver process and then passed to the + workers. When adding more class attributes, make sure the attributes are + serializable picklable. + + TODO: Making Callbacks factory methods that when they are initialized on the + driver process, we do not need to worry about pickling the callback instances. + """ + self._run_name = train_run_context.get_run_config().name + self._thread: Optional[threading.Thread] = None + self._thread_stop_event: Optional[threading.Event] = None + self._metrics_lock: Optional[threading.Lock] = None + self._metrics: Optional[WorkerMetrics] = None + self._worker_tag: Dict[str, str] = {} + self._metrics_gauges: Dict[str, Gauge] = {} + + def _create_prometheus_worker_metrics(self) -> Dict[str, Gauge]: + """Create Prometheus worker metrics for the TrainMetrics dataclass.""" + metrics = {} + for _field in fields(self._metrics): + metric_description = _field.metadata.get("description") + metrics[_field.name] = Gauge( + _field.name, + description=metric_description, + tag_keys=self.WORKER_TAG_KEYS, + ) + return metrics + + def after_init_train_context(self): + """ + Creating a thread to periodically push local metrics to the gauges + after the train context is initialized. + + Note: + This method should be called after the train context is initialized on + each of the worker. The thread should not be created in the `__init__` + method which is called on the train driver process. + """ + self._worker_tag = { + RUN_NAME_TAG_KEY: self._run_name, + WORKER_WORLD_RANK_TAG_KEY: str(get_train_context().get_world_rank()), + } + self._thread_stop_event = threading.Event() + self._metrics_lock = threading.Lock() + self._metrics = WorkerMetrics() + self._metrics_gauges = self._create_prometheus_worker_metrics() + + def push_local_metrics(): + while not self._thread_stop_event.is_set(): + with self._metrics_lock: + metrics_dict = asdict(self._metrics) + for metric_name, metric_value in metrics_dict.items(): + self._metrics_gauges[metric_name].set( + metric_value, self._worker_tag + ) + time.sleep(WorkerMetricsCallback.LOCAL_METRICS_PUSH_INTERVAL_S) + + assert not self._thread + self._thread = threading.Thread(target=push_local_metrics, daemon=True) + self._thread.start() + + def before_worker_shutdown(self): + """ + Stop the thread that pushes local metrics to the metrics gauges before + the worker group shuts down. + """ + # Stop the thread that pushes local metrics to the gauges. + assert not self._thread_stop_event.is_set() + self._thread_stop_event.set() + # Reset the metrics to their default values. + for _field in fields(self._metrics): + self._metrics_gauges[_field.name].set(_field.default, self._worker_tag) + + @contextmanager + def on_report(self): + """ + Context manager to measure the time taken to report a checkpoint to the storage. + """ + start_time_s = time_monotonic() + yield + elapsed_time_s = time_monotonic() - start_time_s + with self._metrics_lock: + self._metrics.train_report_total_blocked_time_s += elapsed_time_s diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/user_callback.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/user_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..b046f965690a9e9abc565de93d50c54854e3b4bb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/user_callback.py @@ -0,0 +1,50 @@ +from typing import Any, Dict, List, Optional + +from ray.train import Checkpoint +from ray.train.v2._internal.execution.callback import ( + ReportCallback, + WorkerGroupCallback, +) +from ray.train.v2._internal.execution.context import TrainRunContext +from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus +from ray.train.v2.api.callback import UserCallback + + +class UserCallbackHandler(WorkerGroupCallback, ReportCallback): + """Responsible for calling methods of subscribers implementing + the `UserCallback` interface. + """ + + def __init__( + self, user_callbacks: List[UserCallback], train_run_context: TrainRunContext + ): + self._user_callbacks = user_callbacks + self._train_run_context = train_run_context + + # -------------------------- + # ReportCallback + # -------------------------- + + def after_report( + self, metrics: List[Dict[str, Any]], checkpoint: Optional[Checkpoint] + ): + for user_callback in self._user_callbacks: + user_callback.after_report( + run_context=self._train_run_context, + metrics=metrics, + checkpoint=checkpoint, + ) + + # -------------------------- + # WorkerGroupCallback + # -------------------------- + + def after_worker_group_poll_status(self, worker_group_status: WorkerGroupStatus): + if not worker_group_status.errors: + return + + for user_callback in self._user_callbacks: + user_callback.after_exception( + run_context=self._train_run_context, + worker_exceptions=worker_group_status.errors, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/working_dir_setup.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/working_dir_setup.py new file mode 100644 index 0000000000000000000000000000000000000000..f43ab2e3094dd54a890aa60b00a7d490ffb7d1ad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/working_dir_setup.py @@ -0,0 +1,24 @@ +import logging +import os + +from ray.train.v2._internal.execution.callback import WorkerGroupCallback +from ray.train.v2._internal.execution.context import get_train_context +from ray.train.v2._internal.execution.worker_group import WorkerGroup + +logger = logging.getLogger(__name__) + + +class WorkingDirectorySetupCallback(WorkerGroupCallback): + def after_worker_group_start(self, worker_group: WorkerGroup): + def chdir_to_working_dir() -> None: + """Create the local working directory for the experiment.""" + local_working_directory = ( + get_train_context().get_storage().local_working_directory + ) + os.makedirs(local_working_directory, exist_ok=True) + logger.debug( + f"Changing the working directory to: {local_working_directory}" + ) + os.chdir(local_working_directory) + + worker_group.execute(chdir_to_working_dir) diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/constants.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..347af27b4a450d0e6b9b62fce08c47c2c2cf47c7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/constants.py @@ -0,0 +1,84 @@ +import os +from typing import Dict + +from ray._private.ray_constants import env_bool, env_set_by_user + +# Unsupported configs can use this value to detect if the user has set it. +_UNSUPPORTED = "UNSUPPORTED" +_DEPRECATED = "DEPRECATED" + +# The name of the file that is used to validate the storage. +VALIDATE_STORAGE_MARKER_FILENAME = ".validate_storage_marker" +# The name of the file that is used to store the checkpoint manager snapshot. +CHECKPOINT_MANAGER_SNAPSHOT_FILENAME = "checkpoint_manager_snapshot.json" + + +# ===================== +# Environment Variables +# ===================== + +# Polling interval for the Train controller. +# This determines how many seconds the controller will wait between +# polling the worker group for its status. +HEALTH_CHECK_INTERVAL_S_ENV_VAR = "RAY_TRAIN_HEALTH_CHECK_INTERVAL_S" +DEFAULT_HEALTH_CHECK_INTERVAL_S: float = 2.0 + +# The time in seconds a worker health check must be hanging for +# before the controller marks the worker as dead and handles the failure. +WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_WORKER_HEALTH_CHECK_TIMEOUT_S" +DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S: float = 10 * 60 + +# Timeout in seconds for the worker group to start. +WORKER_GROUP_START_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_WORKER_GROUP_START_TIMEOUT_S" +DEFAULT_WORKER_GROUP_START_TIMEOUT_S: float = 30.0 + +# Timeout in seconds for `ray.train.report` to block on synchronization barriers, +# after which a timeout error will be raised. +REPORT_BARRIER_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_REPORT_BARRIER_TIMEOUT_S" +DEFAULT_REPORT_BARRIER_TIMEOUT_S: float = 60 * 30 +# Time in seconds for `ray.train.report` to log a warning if it is waiting for sync +# actor notification of releasing. +REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR = "RAY_TRAIN_REPORT_BARRIER_WARN_INTERVAL_S" +DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S: float = 60 + +# The environment variable to enable the Ray Train Metrics. +METRICS_ENABLED_ENV_VAR = "RAY_TRAIN_METRICS_ENABLED" + +# Environment variable to enable the print function patching. +ENABLE_PRINT_PATCH_ENV_VAR = "RAY_TRAIN_ENABLE_PRINT_PATCH" +DEFAULT_ENABLE_PRINT_PATCH = "1" + +# Whether or not to run the controller as an actor. +RUN_CONTROLLER_AS_ACTOR_ENV_VAR = "RAY_TRAIN_RUN_CONTROLLER_AS_ACTOR" +DEFAULT_RUN_CONTROLLER_AS_ACTOR = "1" + +# V2 feature flag. +V2_ENABLED_ENV_VAR = "RAY_TRAIN_V2_ENABLED" + + +def is_v2_enabled() -> bool: + return env_bool(V2_ENABLED_ENV_VAR, False) + + +ENV_VARS_TO_PROPAGATE = { + V2_ENABLED_ENV_VAR, + HEALTH_CHECK_INTERVAL_S_ENV_VAR, + WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR, + WORKER_GROUP_START_TIMEOUT_S_ENV_VAR, + ENABLE_PRINT_PATCH_ENV_VAR, +} + + +def get_env_vars_to_propagate() -> Dict[str, str]: + """Returns a dictionary of environment variables that should be propagated + from the driver to the controller, and then from the controller + to each training worker. + + This way, users only need to set environment variables in one place + when launching the script instead of needing to manually set a runtime environment. + """ + env_vars = {} + for env_var in ENV_VARS_TO_PROPAGATE: + if env_set_by_user(env_var): + env_vars[env_var] = os.environ[env_var] + return env_vars diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/exceptions.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..66072568392615853d71e997feba76d5b36460c6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/exceptions.py @@ -0,0 +1,170 @@ +import os +from typing import Dict, List, Optional + +from ray.train.v2._internal.constants import ( + DEFAULT_WORKER_GROUP_START_TIMEOUT_S, + DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S, + REPORT_BARRIER_TIMEOUT_S_ENV_VAR, + WORKER_GROUP_START_TIMEOUT_S_ENV_VAR, + WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR, +) + + +# TODO: Distinguish between user and system exceptions. +class RayTrainError(Exception): + """Base class for all Ray Train exceptions.""" + + +class WorkerHealthCheckTimeoutError(RayTrainError): + """Exception raised when a worker health check hangs for long enough.""" + + def __init__(self, message): + timeout = os.getenv( + WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR, DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S + ) + message += ( + f"\nSet the {WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR} " + "environment variable to increase the timeout " + f"(current value: {timeout} seconds)." + ) + super().__init__(message) + + +class WorkerHealthCheckFailedError(RayTrainError): + """Exception raised when a worker health check fails.""" + + def __init__(self, message, failure: Exception): + super().__init__(message) + self._message = message + self.health_check_failure = failure + + def __reduce__(self): + return (self.__class__, (self._message, self.health_check_failure)) + + +class TrainingFailedError(RayTrainError): + """Exception raised when training fails.""" + + def __init__(self, worker_failures: Dict[int, Exception]): + super().__init__( + "Training failed due to worker errors. " + "Please inspect the error logs above, " + "or access the latest worker failures in this " + "exception's `worker_failures` attribute." + ) + self.worker_failures = worker_failures + + def __reduce__(self): + return (self.__class__, (self.worker_failures,)) + + +class WorkerGroupStartupTimeoutError(RayTrainError): + """Exception raised when the worker group startup times out. + + Example scenario: 4 GPUs are detected in the cluster, but when the worker + are actually scheduled, one of the nodes goes down and only 3 GPUs are + available. One of the worker tasks may be stuck pending, until a timeout is reached. + """ + + def __init__(self, num_workers: int): + timeout = float( + os.environ.get( + WORKER_GROUP_START_TIMEOUT_S_ENV_VAR, + DEFAULT_WORKER_GROUP_START_TIMEOUT_S, + ) + ) + self.num_workers = num_workers + super().__init__( + f"The worker group startup timed out after {timeout} seconds waiting " + f"for {num_workers} workers. " + "Potential causes include: " + "(1) temporary insufficient cluster resources while waiting for " + "autoscaling (ignore this warning in this case), " + "(2) infeasible resource request where the provided `ScalingConfig` " + "cannot be satisfied), " + "and (3) transient network issues. " + f"Set the {WORKER_GROUP_START_TIMEOUT_S_ENV_VAR} " + "environment variable to increase the timeout." + ) + + def __reduce__(self): + return (self.__class__, (self.num_workers,)) + + +class WorkerGroupStartupFailedError(RayTrainError): + """Exception raised when the worker group fails to start. + + Example scenario: A worker is scheduled onto a node that dies while + the worker actor is initializing. + """ + + +class CheckpointManagerInitializationError(RayTrainError): + """Exception raised when the checkpoint manager fails to initialize from a snapshot. + + Example scenarios: + 1. The checkpoint manager snapshot version is old and + incompatible with the current version of Ray Train. + 2. The checkpoint manager snapshot JSON file is corrupted. + 3. The checkpoint manager snapshot references checkpoints that cannot be found + in the run storage path. + """ + + +class CollectiveTimeoutError(RayTrainError): + """Exception raised when an internal Ray Train collective operation of + the worker group times out. + """ + + +class BroadcastCollectiveTimeoutError(CollectiveTimeoutError): + """Exception raised when the broadcast operation times out. + + There are two main timeout examples: + 1. If not all workers call `ray.train.report`, the entire worker group will + hang until the timeout before raising. This prevents indefinite worker + group hangs. + 2. If a worker is slow in the training loop and fails to reach the broadcast + time, the collective will time out. + """ + + def __init__( + self, time_elapsed: Optional[float], missing_ranks: List[int], timeout_s: float + ): + self._time_elapsed = time_elapsed + self._missing_ranks = missing_ranks + self._timeout_s = timeout_s + + message = ( + f"The broadcast operation timed out after {time_elapsed:.2f} seconds. " + "Please make sure all worker ranks call `ray.train.report`. \n" + f"The following ranks have not called it: {missing_ranks}\n" + f"You can set this timeout with the {REPORT_BARRIER_TIMEOUT_S_ENV_VAR} " + f"environment variable (current value: {timeout_s:.2f} s)." + ) + super().__init__(message) + + def __reduce__(self): + return ( + self.__class__, + (self._time_elapsed, self._missing_ranks, self._timeout_s), + ) + + +class UserExceptionWithTraceback(RayTrainError): + """This class wraps a user code exception raised on the worker + with its original traceback string, for logging and debugging purposes. + + This is needed because the original exception traceback is not serialized + with the exception when it is *returned* back to the main process. + """ + + def __init__(self, exc: BaseException, traceback_str: str): + self._base_exc = exc + self._traceback_str = traceback_str + + def __reduce__(self): + return (self.__class__, (self._base_exc, self._traceback_str)) + + def __str__(self): + return self._traceback_str diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__init__.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/callback.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/callback.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c88d9bdbd562fca9b418455fc3c415e5774e3cdf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/callback.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2d6c3f76f7b64b7c6bf26868d20dfa71180701b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/controller.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/controller.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..739f774f5c03b9b177a277d8569a26dbb3345936 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/controller.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/storage.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/storage.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6c26dcdb32a1df4f9733951f70621fafd8aad3f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/storage.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/callback.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a6ada03348b73f838086f2fc5daa05b8ba2ab3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/callback.py @@ -0,0 +1,140 @@ +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from ray.train.v2.api.callback import RayTrainCallback +from ray.util.annotations import DeveloperAPI + +if TYPE_CHECKING: + from ray.train import Checkpoint + from ray.train.v2._internal.execution.controller import TrainControllerState + from ray.train.v2._internal.execution.failure_handling import FailureDecision + from ray.train.v2._internal.execution.scaling_policy import ScalingDecision + from ray.train.v2._internal.execution.worker_group import ( + WorkerGroup, + WorkerGroupStatus, + ) + + +@DeveloperAPI +class WorkerGroupCallback(RayTrainCallback): + def before_init_train_context( + self, worker_group: "WorkerGroup" + ) -> Dict[str, List[Any]]: + """Called before initializing the TrainContext for the worker_group. + + Return: + A dictionary of additional arguments for TrainContext. + The key is the argument name and the value is a list of argument values + to pass to the TrainContext constructor of each worker in the worker group. + """ + return {} + + @contextmanager + def on_worker_group_start(self): + yield + + def after_worker_group_start(self, worker_group: "WorkerGroup"): + """Called after the worker group actors are initialized. + All workers should be ready to execute tasks.""" + pass + + def after_worker_group_training_start(self, worker_group: "WorkerGroup"): + pass + + @contextmanager + def on_worker_group_shutdown(self): + yield + + def before_worker_group_shutdown(self, worker_group: "WorkerGroup"): + """Called before the worker group is shut down. + Workers may be dead at this point due to actor failures, so this method + should catch and handle exceptions if attempting to execute tasks.""" + pass + + def after_worker_group_poll_status(self, worker_group_status: "WorkerGroupStatus"): + pass + + +@DeveloperAPI +class ControllerCallback(RayTrainCallback): + def after_controller_start(self): + """Called immediately after `TrainController.run` is called, + before the control loop starts executing.""" + pass + + def before_controller_shutdown(self): + """Called before `TrainController.run` exits, + after the control loop has exited.""" + pass + + def after_controller_state_update( + self, + previous_state: "TrainControllerState", + current_state: "TrainControllerState", + ): + """Called whenever the controller state is updated.""" + pass + + def before_controller_execute_failure_decision( + self, + failure_decision: "FailureDecision", + worker_group_status: "WorkerGroupStatus", + ): + """Called before the controller executes a failure decision.""" + pass + + def before_controller_execute_scaling_decision( + self, + scaling_decision: "ScalingDecision", + worker_group_status: "WorkerGroupStatus", + ): + """Called before the controller executes a scaling decision.""" + pass + + +@DeveloperAPI +class ReportCallback(RayTrainCallback): + def after_report( + self, metrics: List[Dict[str, Any]], checkpoint: Optional["Checkpoint"] + ): + """Called after all workers have reported a training result. + + Note that this differs from `after_worker_group_poll_status`, + which may only contain a subset of workers that have reported. + For example, if only rank 0 is performing checkpointing, then + rank 0 would report a training result the slowest. + """ + pass + + +@DeveloperAPI +class WorkerCallback(RayTrainCallback): + """ + Callbacks that are hooked to the worker event. + + These callbacks are created on the train driver process and then + copied and passed to all the workers. + The execution of these callbacks happens on each of the workers, + not on the train driver process. + """ + + def after_init_train_context(self): + pass + + def before_worker_shutdown(self): + pass + + +@DeveloperAPI +class TrainContextCallback(RayTrainCallback): + """ + Callbacks that are hooked to the train context event. + + These callbacks are created on the train driver process and then + copied and passed to all the workers. + The execution of these callbacks happens on the train context of the workers. + """ + + @contextmanager + def on_report(self): + yield diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__init__.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..052fc63826209b38a5d578434d7bbdf69d388149 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/checkpoint_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/checkpoint_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2251940ceeea22c3daa9f05ff3783f35f341198f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/checkpoint_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/report_handler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/report_handler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3645e6275e34a1adf3a211f38f5e3f87ad4eefaa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/report_handler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/sync_actor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/sync_actor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13efbe802b3c65ad795da0e0848c291021682e2c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/sync_actor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..688c6168f70db04f02e9567880e659647b95d734 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py @@ -0,0 +1,271 @@ +import logging +from typing import Any, Dict, List, Optional + +from ray.air.config import CheckpointConfig +from ray.train._checkpoint import Checkpoint +from ray.train._internal.checkpoint_manager import ( + _CheckpointManager, + _insert_into_sorted_list, +) +from ray.train._internal.session import _TrainingResult +from ray.train.v2._internal.exceptions import CheckpointManagerInitializationError +from ray.train.v2._internal.execution.callback import ReportCallback +from ray.train.v2._internal.execution.context import StorageContext +from ray.train.v2._internal.execution.storage import _delete_fs_path, _exists_at_fs_path + +try: + from pydantic import BaseModel + from pydantic_core import from_json +except (ImportError, ModuleNotFoundError) as exc: + raise ImportError( + "`ray.train.v2` requires the pydantic package, which is missing. " + "Run the following command to fix this: `pip install pydantic`" + ) from exc + + +logger = logging.getLogger(__name__) + + +class _TrainingResultState(BaseModel): + # Increment version if the schema changes + version: int = 0 + checkpoint_dir_name: str + metrics: dict + + +class _CheckpointManagerState(BaseModel): + # Increment version if the schema changes + version: int = 0 + checkpoint_results: List[_TrainingResultState] + latest_checkpoint_result: Optional[_TrainingResultState] + + +def _get_training_result_from_state( + state: _TrainingResultState, + storage_context: StorageContext, +) -> _TrainingResult: + """Get a TrainingResult object from a Pydantic state object.""" + return _TrainingResult( + checkpoint=Checkpoint( + path=storage_context.build_checkpoint_path_from_name( + state.checkpoint_dir_name + ), + filesystem=storage_context.storage_filesystem, + ), + metrics=state.metrics, + ) + + +def _get_state_from_training_result( + training_result: _TrainingResult, + storage_context: StorageContext, +) -> _TrainingResultState: + """Get a Pydantic state object from a TrainingResult object.""" + return _TrainingResultState( + checkpoint_dir_name=storage_context.extract_checkpoint_dir_name_from_path( + training_result.checkpoint.path + ), + metrics=training_result.metrics, + ) + + +class CheckpointManager(_CheckpointManager, ReportCallback): + def __init__( + self, + checkpoint_config: CheckpointConfig, + storage_context: StorageContext, + ): + self._storage_context = storage_context + self._checkpoint_config = checkpoint_config + super().__init__(checkpoint_config) + # If the snapshot is found, the checkpoint manager will restore its state. + self._maybe_load_state_from_storage() + + def register_checkpoint(self, checkpoint_result: _TrainingResult): + """Register new checkpoint and add to bookkeeping. + + This method will register a new checkpoint and add it to the internal + bookkeeping logic. This means the checkpoint manager will decide if + this checkpoint should be kept, and if older or worse performing + checkpoints should be deleted. + + Args: + checkpoint: Tracked checkpoint object to add to bookkeeping. + """ + self._latest_checkpoint_result = checkpoint_result + + if self._checkpoint_config.checkpoint_score_attribute is not None: + # If we're ordering by a score, insert the checkpoint + # so that the list remains sorted. + _insert_into_sorted_list( + self._checkpoint_results, + checkpoint_result, + key=self._get_checkpoint_score, + ) + else: + # If no metric is provided, just append (ordering by time of registration). + self._checkpoint_results.append(checkpoint_result) + + results_to_delete = {} + if self._checkpoint_config.num_to_keep is not None: + # Delete the bottom (N - K) checkpoints + worst_results = set( + self._checkpoint_results[: -self._checkpoint_config.num_to_keep] + ) + # Except for the latest checkpoint. + results_to_delete = worst_results - {self._latest_checkpoint_result} + + # Update internal state before actually deleting them. + self._checkpoint_results = [ + checkpoint_result + for checkpoint_result in self._checkpoint_results + if checkpoint_result not in results_to_delete + ] + + # Save the checkpoint manager state to storage. + # Note: We save the state before deleting the old checkpoints. + # If deletion happens first and the process crashes, our snapshot + # may point to some stale checkpoints that are already deleted. + # TODO: Make this writing operation non-blocking. + self._write_state_to_storage() + + # Delete the old checkpoints. + for checkpoint_result in results_to_delete: + checkpoint = checkpoint_result.checkpoint + logger.debug("Deleting checkpoint: ", checkpoint) + _delete_fs_path(fs=checkpoint.filesystem, fs_path=checkpoint.path) + + # -------------------------- + # CheckpointManager state + # -------------------------- + + def _save_state(self) -> str: + """Save the checkpoint manager state to a JSON str.""" + + checkpoint_results = [ + _get_state_from_training_result(checkpoint_result, self._storage_context) + for checkpoint_result in self._checkpoint_results + ] + + latest_checkpoint_result = ( + _get_state_from_training_result( + self._latest_checkpoint_result, self._storage_context + ) + if self._latest_checkpoint_result is not None + else None + ) + + manager_snapshot = _CheckpointManagerState( + checkpoint_results=checkpoint_results, + latest_checkpoint_result=latest_checkpoint_result, + ) + return manager_snapshot.model_dump_json() + + def _load_state(self, json_state: str): + """Load the checkpoint manager state from a JSON str.""" + try: + manager_snapshot = _CheckpointManagerState.model_validate( + from_json(json_state) + ) + except Exception as e: + raise CheckpointManagerInitializationError(repr(e)) from e + self._assert_checkpoints_exist() + + self._checkpoint_results = [ + _get_training_result_from_state( + training_result_state, self._storage_context + ) + for training_result_state in manager_snapshot.checkpoint_results + ] + + self._latest_checkpoint_result = ( + _get_training_result_from_state( + manager_snapshot.latest_checkpoint_result, self._storage_context + ) + if manager_snapshot.latest_checkpoint_result is not None + else None + ) + + def _maybe_load_state_from_storage(self): + """Load the checkpoint manager state from storage. + If no snapshot is found, start with a clean state. + """ + if not _exists_at_fs_path( + fs=self._storage_context.storage_filesystem, + fs_path=self._storage_context.checkpoint_manager_snapshot_path, + ): + logger.debug( + "No checkpoint manager snapshot found. " + "No checkpoint will be available via `ray.train.get_checkpoint`, " + "so training will start from scratch." + ) + return + with self._storage_context.storage_filesystem.open_input_stream( + self._storage_context.checkpoint_manager_snapshot_path + ) as f: + logger.info( + "A run snapshot was found in storage folder at: " + f"'{self._storage_context.experiment_fs_path}'\n" + "This snapshot contains a list of checkpoints reported via " + "`ray.train.report` and will be loaded. " + "This allows the latest checkpoint found in the snapshot to be " + "accessible within your training function via " + "`ray.train.get_checkpoint`.\n" + "If you meant to start a brand new training job without any " + "information about previous checkpoints found in this directory, " + "please configure a new, unique `RunConfig(name)` or delete the " + f"existing folder at '{self._storage_context.experiment_fs_path}'." + ) + json_state = f.read().decode("utf-8") + self._load_state(json_state) + + def _write_state_to_storage(self): + """Write the checkpoint manager state to storage.""" + checkpoint_manager_snapshot = self._save_state() + with self._storage_context.storage_filesystem.open_output_stream( + self._storage_context.checkpoint_manager_snapshot_path + ) as f: + f.write(checkpoint_manager_snapshot.encode("utf-8")) + + def _assert_checkpoints_exist(self): + """Validate the checkpoint manager state. + + This method will validate the checkpoint manager state by checking if + the checkpoints specified in manager snapshot is compatible with the + checkpoint folders of the experiment storage filesystem. + + Raises: + CheckpointManagerInitializationError: If the checkpoint manager snapshot + is not consistent with the stored checkpoints. + """ + for checkpoint_result in self._checkpoint_results: + checkpoint = checkpoint_result.checkpoint + assert checkpoint is not None + if not _exists_at_fs_path( + fs=checkpoint.filesystem, fs_path=checkpoint.path + ): + raise CheckpointManagerInitializationError( + message=( + "The run snapshot contains a reference to a checkpoint " + f"that does not exist anymore ({checkpoint}). You are " + "running in a corrupted run directory `experiment_fs_path`." + "Please configure a new, unique `RunConfig(name)` " + "or delete the existing folder at " + f"`{self._storage_context.experiment_fs_path}`." + ) + ) + + # -------------------------- + # ReportCallback + # -------------------------- + + def after_report( + self, metrics: List[Dict[str, Any]], checkpoint: Optional[Checkpoint] + ): + if not checkpoint: + return + + rank_0_metrics = metrics[0] + self.register_checkpoint( + _TrainingResult(checkpoint=checkpoint, metrics=rank_0_metrics) + ) diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/report_handler.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/report_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..a0d9201e8e747f461f82a7793652c1bc242e5dd4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/report_handler.py @@ -0,0 +1,111 @@ +from collections import deque +from typing import TYPE_CHECKING, Deque, List, Optional + +from ray.train.v2._internal.execution.callback import ( + ReportCallback, + WorkerGroupCallback, +) +from ray.train.v2._internal.execution.worker_group import WorkerGroup, WorkerGroupStatus + +if TYPE_CHECKING: + from ray.train._internal.session import _TrainingResult + + +class ReportCallbackHandler(WorkerGroupCallback): + """Consolidate training results from multiple workers and call + subscribers implementing the `ReportCallback` interface sequentially. + """ + + def __init__(self, report_callbacks: List[ReportCallback]): + # Number of workers in the current worker group. It is initialized + # to be None. It is set to the number of workers when it receives the + # worker group status for the first time. + # When a worker group shutdown, self._num_workers is set to None, + # waiting to be updated when a new worker group status is received again. + self._num_workers: Optional[int] = None + # A list of queues holding training results from workers. + self._training_result_queues: Optional[List[Deque[_TrainingResult]]] = None + + self._report_callbacks = report_callbacks + + # -------------------------- + # WorkerGroupCallback + # -------------------------- + + def after_worker_group_poll_status( + self, worker_group_status: WorkerGroupStatus + ) -> None: + """Handle training results as they roll in from worker status polls. + + Wait for all workers to report training results to collect + a consolidated training result. + """ + # Step 1: If self._num_workers is None, we need to initialize the number + # of workers and training_results_queues from the worker group status. This + # happens when the handler receives the worker group status for the first time. + assert ( + self._num_workers and self._training_result_queues + ), "Need to call initialize state with `after_worker_group_start` first." + + assert self._num_workers == worker_group_status.num_workers, ( + f"The number of workers in the worker group has changed unexpectedly. " + f"Expected: {self._num_workers}, got: {worker_group_status.num_workers}" + ) + + # Step 2: Update training_results_queues with poll_results. + for i in range(self._num_workers): + training_result = worker_group_status.worker_statuses[i].training_result + if training_result: + self._training_result_queues[i].append(training_result) + + # Directly return if any of the worker result queues are empty. + if not all(self._training_result_queues): + return + + training_results = [q.popleft() for q in self._training_result_queues] + + # Step 3: Consolidate a list of checkpoints to single checkpoint. + # Use the first checkpoint as the consolidated checkpoint. + checkpoint_results = [ + tr for tr in training_results if tr.checkpoint is not None + ] + + consolidated_checkpoint = None + if checkpoint_results: + # Double check the storage path of the checkpoints in the training results. + unique_checkpoint_paths = {tr.checkpoint.path for tr in checkpoint_results} + if len(unique_checkpoint_paths) > 1: + # TODO: Support for inconsistent checkpoints path from workers + # instead of hard raising error. Maybe drop this iteration of + # training results and continue with the next iteration. + raise RuntimeError( + "The storage path of the checkpoints in the training results " + "is not the same. This means the checkpoints are not consistent." + "Got a mix of the following checkpoint paths: " + f"{unique_checkpoint_paths}\n" + "This is unexpected -- please file a Github issue." + ) + consolidated_checkpoint = checkpoint_results[0].checkpoint + + # Step 4: Invoke all dependent `ReportCallback`s. + metrics_per_worker = [ + training_result.metrics for training_result in training_results + ] + for callback in self._report_callbacks: + callback.after_report( + metrics=metrics_per_worker, + checkpoint=consolidated_checkpoint, + ) + + def after_worker_group_start(self, worker_group: WorkerGroup) -> None: + """Handle worker group start. Initialize internal states.""" + self._num_workers = len(worker_group) + self._training_result_queues = [deque() for _ in range(self._num_workers)] + + def before_worker_group_shutdown(self, worker_group: WorkerGroup) -> None: + """Handle worker group shutdown. Clear internal states. + + None of the partial reported results are valid at this point. + """ + self._num_workers = None + self._training_result_queues = None diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/sync_actor.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/sync_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..f467290c71bf8c4b908ee66c25fb4796fc5eb133 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/sync_actor.py @@ -0,0 +1,190 @@ +import asyncio +import logging +from contextlib import contextmanager +from typing import List, Optional, TypeVar + +import ray +from ray.train.v2._internal.constants import ( + DEFAULT_REPORT_BARRIER_TIMEOUT_S, + DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S, + REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR, +) +from ray.train.v2._internal.exceptions import BroadcastCollectiveTimeoutError + +T = TypeVar("T", bound=Optional[object]) +logger = logging.getLogger(__name__) + + +BROADCAST_PERIODIC_WARNING = """ +`ray.train.report` has not been called by all {world_size} workers in the group. + +The workers have been waiting for {max_time_elapsed_s:.2f} s for the following ranks +to join the `report` call: {missing_ranks}. + +Please ensure that all workers call `ray.train.report` regardless of whether +they participate in checkpointing or not (e.g., pass `checkpoint=None` for ranks +that do not save a checkpoint). Also ensure that workers are not hanging on +other operations, causing them to miss this synchronization barrier. + +You can set the {warn_interval_env_var} environment variable to change the frequency +of this warning (current value: {warn_interval_s} s). +""" + + +@ray.remote(num_cpus=0) # type: ignore +class SynchronizationActor: + """A Ray actor that synchronizes the workers in a distributed training job. + + This actor forms a synchronization barrier on a group of processes. + Every time a worker calls the broadcast_from_rank_zero method, + the counter is incremented. When the counter equals to the world size, + the actor notifies all the workers to continue. + """ + + def __init__( + self, + timeout_s: float = DEFAULT_REPORT_BARRIER_TIMEOUT_S, + warn_interval_s: float = DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S, + ): + self._counter: int = 0 + self._world_size: int = 0 + self._condition = asyncio.Condition() + self._reduced_data = None + # The time when workers from different ranks + # enters the synchronization barrier. + self._sync_start_times: List[Optional[float]] = [] + # The timeout in seconds for the synchronization barrier. + self._timeout_s: float = timeout_s + # The interval in seconds to log a warning when waiting for the barrier. + self._warn_interval_s: float = warn_interval_s + + def get_counter(self): + """Returns the current value of the counter.""" + return self._counter + + def get_world_size(self): + """Returns the current value of the world_size.""" + return self._world_size + + def get_reduced_data(self): + """Returns the current value of the reduced_data.""" + return self._reduced_data + + def _clear_states(self): + """Clears the states of the actor. When the last worker has + called the _clear_states method, the actor clears its states + """ + self._counter -= 1 + if self._counter == 0: + self._reduced_data = None + self._world_size = 0 + + def _setup_or_validate_collective_op(self, world_size: int): + """The setup method for the synchronization actor if it is not setup yet. + It initializes the world size and the start times for the + synchronization barrier. + """ + if self._world_size == 0: + self._world_size = world_size + self._sync_start_times = [None] * world_size + elif world_size != self._world_size: + raise ValueError( + f"Expected all callers to provide the same world size. \ + Got {world_size} and expected {self._world_size}." + ) + + @contextmanager + def _broadcast_collective_context_manager( + self, world_rank: int, world_size: int, data: T + ): + """A context manager that ensures the synchronization barrier is lifted + after the block of code is executed. + """ + try: + self._setup_or_validate_collective_op(world_size) + if world_rank == 0: + self._reduced_data = data + if self._counter < self._world_size: + self._counter += 1 + yield + finally: + self._clear_states() + + def _get_time_elapsed(self) -> Optional[float]: + """Return the time elapsed since the first worker entered the barrier. + If no workers have entered the barrier, returns None. + """ + start_times = [t for t in self._sync_start_times if t is not None] + if not start_times: + return None + + return asyncio.get_event_loop().time() - min(start_times) + + def _get_missing_ranks(self) -> List[int]: + """Returns the ranks that have not entered the synchronization barrier.""" + return [i for i, t in enumerate(self._sync_start_times) if t is None] + + async def _wait_with_logging(self, condition, world_rank: int): + """Waits for the condition to be notified, logging an warning every + `log_interval` seconds, and raises a timeout error if `timeout` is reached. + """ + current_time = asyncio.get_event_loop().time() + self._sync_start_times[world_rank] = current_time + while True: + try: + await asyncio.wait_for(condition.wait(), timeout=self._warn_interval_s) + return + # asyncio.wait_for() raises `asyncio.TimeoutError` for asyncio<=3.10 + # and raises `TimeoutError` for asyncio>=3.11 + # https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for + # TODO: (hpguo) Make only one worker log the warning message. + except (asyncio.TimeoutError, TimeoutError): + logger.warning( + BROADCAST_PERIODIC_WARNING.format( + world_size=self._world_size, + max_time_elapsed_s=self._get_time_elapsed(), + missing_ranks=self._get_missing_ranks(), + warn_interval_env_var=REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR, + warn_interval_s=self._warn_interval_s, + ) + ) + + async def broadcast_from_rank_zero( + self, world_rank: int, world_size: int, data: T + ) -> T: + """Broadcasts a data from the worker with rank 0 to all other workers. + + This method is a coroutine that blocks until all workers have called this + method with the their data. The data from the worker with rank 0 will + be returned. + """ + # Ensures that all global states manipulation is done within the async context + # manager which makes the condition variable awaiting and the counter + # incrementing an atomic operation. + async with self._condition: + with self._broadcast_collective_context_manager( + world_rank, world_size, data + ): + # If the counter is equal to the world size, it means the last worker + # has called the broadcast_from_rank_zero method. The actor notifies + # all the workers to continue. + if self._counter == self._world_size: + self._condition.notify_all() + return self._reduced_data + # If the counter is less than the world size, the actor waits for the + # other workers to call the broadcast_from_rank_zero method. + try: + await asyncio.wait_for( + self._wait_with_logging(self._condition, world_rank), + timeout=self._timeout_s, + ) + return self._reduced_data + except (asyncio.TimeoutError, TimeoutError) as e: + raise BroadcastCollectiveTimeoutError( + time_elapsed=self._get_time_elapsed(), + missing_ranks=self._get_missing_ranks(), + timeout_s=self._timeout_s, + ) from e + + # TODO: Implement a general consensus_from_votes method that takes a callable + # reduce_fn and a list of votes from each worker. The method returns the consensus diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/context.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/context.py new file mode 100644 index 0000000000000000000000000000000000000000..5b7dae52a46adb9defdeb4fa2796e42ddab4cb73 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/context.py @@ -0,0 +1,281 @@ +import logging +import threading +from dataclasses import dataclass +from queue import Queue +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +import ray +from ray.data.iterator import DataIterator +from ray.train import Checkpoint +from ray.train._internal import session +from ray.train._internal.session import _TrainingResult +from ray.train.v2._internal.execution.checkpoint.sync_actor import SynchronizationActor +from ray.train.v2._internal.execution.storage import StorageContext +from ray.train.v2._internal.util import _copy_doc, invoke_context_managers +from ray.train.v2.api.config import RunConfig + +if TYPE_CHECKING: + from ray.train.v2._internal.execution.callback import TrainContextCallback + from ray.train.v2._internal.execution.worker_group.thread_runner import ThreadRunner + + +logger = logging.getLogger(__file__) + + +@dataclass +class TrainRunContext: + """Holds the metadata and context for the current training run.""" + + # TODO: Make this dataclass immutable after refactoring the train context. + + # The run configuration for the current training run. + run_config: RunConfig + + # TODO: Add more fields that are shared across all workers and controllers. + # For example, StorageContext, ScalingConfig, etc. + + def get_run_config(self) -> RunConfig: + """Returns the run config of the current training run.""" + return self.run_config + + +@dataclass(frozen=True) +class DistributedContext: + world_rank: int + world_size: int + local_rank: int + local_world_size: int + node_rank: int + + +@dataclass(frozen=True) +class ExecutionContext: + """Holds the execution context for the current worker process. + + Every worker process has a single execution context accessed via the + `TrainContext`, which includes the training thread that is actually + running the user code. + """ + + # A shared synchronization actor that helps broadcast data across ranks. + synchronization_actor: SynchronizationActor + + # A queue that receives training results from the user training code. + # `ray.train.report` in user code populates this queue. + result_queue: Queue + + # The thread launcher that runs the user training loop. + training_thread_runner: "ThreadRunner" + + # The callbacks that are run in the worker train context. + train_context_callbacks: List["TrainContextCallback"] + + +@dataclass +class TrainContext(TrainRunContext): + distributed_context: DistributedContext + execution_context: ExecutionContext + storage_context: StorageContext + dataset_shards: Dict[str, DataIterator] + checkpoint: Optional[Checkpoint] = None + + @_copy_doc(session.get_metadata) + def get_metadata(self) -> Dict[str, Any]: + raise NotImplementedError + + @_copy_doc(session.get_experiment_name) + def get_experiment_name(self) -> str: + # TODO: Resolve run_config.name if it is None + return self.run_config.name + + @_copy_doc(session.get_trial_name) + def get_trial_name(self) -> str: + raise NotImplementedError + + @_copy_doc(session.get_trial_id) + def get_trial_id(self) -> str: + raise NotImplementedError + + @_copy_doc(session.get_trial_resources) + def get_trial_resources(self): + raise NotImplementedError + + @_copy_doc(session.get_trial_dir) + def get_trial_dir(self) -> str: + raise NotImplementedError + + @_copy_doc(session.get_world_size) + def get_world_size(self) -> int: + return self.distributed_context.world_size + + @_copy_doc(session.get_world_rank) + def get_world_rank(self) -> int: + return self.distributed_context.world_rank + + @_copy_doc(session.get_local_rank) + def get_local_rank(self) -> int: + return self.distributed_context.local_rank + + @_copy_doc(session.get_local_world_size) + def get_local_world_size(self) -> int: + return self.distributed_context.local_world_size + + @_copy_doc(session.get_node_rank) + def get_node_rank(self) -> int: + return self.distributed_context.node_rank + + @_copy_doc(session.get_storage) + def get_storage(self): + return self.storage_context + + def get_result_queue(self): + return self.execution_context.result_queue + + def get_synchronization_actor(self): + return self.execution_context.synchronization_actor + + def get_checkpoint(self): + return self.checkpoint + + def get_dataset_shard(self, dataset_name: str) -> DataIterator: + """Returns the :class:`ray.data.DataIterator` shard for this worker. + + Call :meth:`~ray.data.DataIterator.iter_torch_batches` or + :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the + appropriate framework-specific data type. + + Args: + dataset_name: Name of the dataset shard. + Returns: + The ``DataIterator`` shard with the given name for this worker. + Raises: + KeyError: If the dataset shard with the given name is not found. + """ + try: + return self.dataset_shards[dataset_name] + except KeyError: + raise KeyError( + f"Dataset {dataset_name} not found. Available datasets: " + f"{list(self.dataset_shards.keys())}." + ) + + def get_context_callbacks(self) -> List["TrainContextCallback"]: + return self.execution_context.train_context_callbacks + + def _sync_checkpoint_dir_name_across_ranks( + self, checkpoint_dir_name: Optional[str] = None + ) -> str: + """Sync the checkpoint dir name across ranks. + + Args: + checkpoint_dir_name: The checkpoint dir name to sync. + + Returns: + The synced checkpoint dir name. + """ + # If checkpoint_dir_name is not set, use default checkpoint_dir_name + # created by the storage context. + checkpoint_dir_name = ( + checkpoint_dir_name + or self.storage_context.make_default_checkpoint_dir_name() + ) + # Get a consensus across ranks on the remote storage path, so distributed + # checkpoints will be stored to the same place. + sync_actor = self.get_synchronization_actor() + return ray.get( + sync_actor.broadcast_from_rank_zero.remote( + world_rank=self.distributed_context.world_rank, + world_size=self.distributed_context.world_size, + data=checkpoint_dir_name, + ) + ) + + def _save_checkpoint( + self, + checkpoint_dir_name: str, + metrics: Dict[str, Any], + checkpoint: Optional[Checkpoint] = None, + ) -> _TrainingResult: + """Save the checkpoint to remote storage. + + Returns: + The training result object containing the persisted checkpoint. + """ + + if not checkpoint: + return _TrainingResult(checkpoint=None, metrics=metrics) + + # Persist the checkpoint to the remote storage path. + persisted_checkpoint = self.storage_context.persist_current_checkpoint( + checkpoint, checkpoint_dir_name + ) + # Update latest checkpoint as the persisted checkpoint. + self.checkpoint = persisted_checkpoint + + return _TrainingResult(checkpoint=persisted_checkpoint, metrics=metrics) + + def report( + self, + metrics: Dict[str, Any], + checkpoint: Optional[Checkpoint] = None, + checkpoint_dir_name: Optional[str] = None, + ): + """ + Upload checkpoint to remote storage and put a training + result on the result queue of this worker process. + + Args: + metrics: The metrics to report. + checkpoint: The checkpoint to report. + checkpoint_dir_name: The name of the checkpoint dir + in this iteration. Note: If not set, the checkpoint will + be stored in the default storage path. If set, make sure + this value is unique for each iteration. + + TODO: the report function should be implemented in the worker instead + of in the train context. The train context should only keep the train + related information and not the worker related actions. This refactor + would also require the `TrainContextCallback` to be updated as well. + """ + + with invoke_context_managers( + [ + callback.on_report + for callback in self.execution_context.train_context_callbacks + ] + ): + # Step 1: sync the checkpoint dir name across ranks. + checkpoint_dir_name = self._sync_checkpoint_dir_name_across_ranks( + checkpoint_dir_name + ) + # Step 2: save the checkpoint to remote storage. + training_result = self._save_checkpoint( + checkpoint_dir_name, metrics, checkpoint + ) + # Step 3: Report the training result to the result queue. + # The queue size is set to 1 to avoid accumulating unprocessed results. + # If the queue is full, the put operation blocks until a result is consumed. + + # TODO (hpguo): Add a metrics to track the blocking time waiting for the + # training result to be consumed by the controller. + self.get_result_queue().put(training_result) + + +# The global variable holding the current TrainContext +_train_context: Optional[TrainContext] = None + +# Thread lock to protect the global TrainContext +_context_lock = threading.Lock() + + +def get_train_context() -> TrainContext: + with _context_lock: + if _train_context is None: + raise RuntimeError("TrainContext has not been initialized.") + return _train_context + + +def set_train_context(context) -> None: + global _train_context + with _context_lock: + _train_context = context diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/controller.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b65dc6f6e2f2d4b789f5cbc7182df88670e6f0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/controller.py @@ -0,0 +1,377 @@ +import logging +import os +import time +from enum import Enum +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +from ray._private.auto_init_hook import wrap_auto_init +from ray.train import Checkpoint +from ray.train.v2._internal.constants import ( + DEFAULT_HEALTH_CHECK_INTERVAL_S, + HEALTH_CHECK_INTERVAL_S_ENV_VAR, +) +from ray.train.v2._internal.exceptions import ( + TrainingFailedError, + WorkerGroupStartupFailedError, + WorkerGroupStartupTimeoutError, +) +from ray.train.v2._internal.execution.callback import ( + ControllerCallback, + ReportCallback, + TrainContextCallback, + WorkerCallback, + WorkerGroupCallback, +) +from ray.train.v2._internal.execution.checkpoint.checkpoint_manager import ( + CheckpointManager, +) +from ray.train.v2._internal.execution.checkpoint.report_handler import ( + ReportCallbackHandler, +) +from ray.train.v2._internal.execution.context import TrainRunContext +from ray.train.v2._internal.execution.failure_handling import ( + FailureDecision, + FailurePolicy, +) +from ray.train.v2._internal.execution.scaling_policy import ( + ResizeDecision, + ScalingDecision, + ScalingPolicy, +) +from ray.train.v2._internal.execution.storage import StorageContext, get_fs_and_path +from ray.train.v2._internal.execution.worker_group import WorkerGroup, WorkerGroupStatus +from ray.train.v2._internal.logging.logging import configure_controller_logger +from ray.train.v2._internal.util import time_monotonic +from ray.train.v2.api.result import Result +from ray.train.v2.api.callback import RayTrainCallback + +logger = logging.getLogger(__name__) + + +class TrainControllerState(Enum): + """The possible states that the training controller can be in + while running the main execution control loop. + + States: + RUNNING: The training controller is actively running training tasks. + RECOVERING: The training controller is in the process of recovering + from an error. + INITIALIZING: The train controller is starting up. + This is always the initial state of the controller. + ERRORED: A terminal state indicating that training has encountered + an error and cannot continue. + FINISHED: A terminal state indicating that training has completed. + """ + + RUNNING = "RUNNING" + INITIALIZING = "INITIALIZING" + RECOVERING = "RECOVERING" + ERRORED = "ERRORED" + FINISHED = "FINISHED" + + +class TrainController: + """Manages the execution of a distributed training job. + + Responsibilities include: + * Triggering the training function to run on the worker group. + * Monitoring the status of the worker group. + * Handling scaling decisions by restarting the worker group. + * Handling failure decisions by restarting the worker group or terminating training. + * Running callback logic on different hooks in the control loop. + """ + + worker_group_cls = WorkerGroup + + def __init__( + self, + train_fn: Callable[[Dict[str, Any]], None], + train_run_context: TrainRunContext, + scaling_policy: ScalingPolicy, + failure_policy: FailurePolicy, + callbacks: Optional[List[RayTrainCallback]] = None, + # TODO: [Deprecation] + resume_from_checkpoint: Optional[Checkpoint] = None, + ): + self._train_run_context = train_run_context + configure_controller_logger(self._train_run_context) + self._train_fn = train_fn + self._scaling_policy = scaling_policy + self._failure_policy = failure_policy + self._run_config = self._train_run_context.run_config + self._callbacks = callbacks or [] + self._resume_from_checkpoint = resume_from_checkpoint + self._storage_context = StorageContext( + storage_path=self._run_config.storage_path, + experiment_dir_name=self._run_config.name, + storage_filesystem=self._run_config.storage_filesystem, + ) + + self._checkpoint_manager = CheckpointManager( + checkpoint_config=self._run_config.checkpoint_config, + storage_context=self._storage_context, + ) + report_handler = ReportCallbackHandler( + report_callbacks=( + [self._checkpoint_manager] + + [c for c in self._callbacks if isinstance(c, ReportCallback)] + ) + ) + + # Group callbacks by the hooks they're subscribed to. + self._controller_callbacks = [self._scaling_policy] + [ + c for c in self._callbacks if isinstance(c, ControllerCallback) + ] + # Group callbacks that will be propagated to the worker group, + # train worker and the train context. + worker_group_callbacks_to_propagate = [report_handler] + [ + c + for c in self._callbacks + if isinstance( + c, (WorkerGroupCallback, WorkerCallback, TrainContextCallback) + ) + ] + + self._worker_group = self.worker_group_cls( + train_run_context=self._train_run_context, + callbacks=worker_group_callbacks_to_propagate, + ) + self._state = TrainControllerState.INITIALIZING + + self._latest_poll_time = float("-inf") + self._health_check_interval_s = float( + os.getenv(HEALTH_CHECK_INTERVAL_S_ENV_VAR, DEFAULT_HEALTH_CHECK_INTERVAL_S) + ) + self._training_failed_error: Optional[TrainingFailedError] = None + + def _execute_scaling_decision( + self, decision: ScalingDecision, worker_group_status: WorkerGroupStatus + ): + """Executes scaling decisions.""" + for callback in self._controller_callbacks: + callback.before_controller_execute_scaling_decision( + decision, worker_group_status + ) + + if isinstance(decision, ResizeDecision): + self._restart_worker_group( + num_workers=decision.num_workers, + resources_per_worker=decision.resources_per_worker, + ) + + def _execute_failure_decision( + self, failure_decision: FailureDecision, worker_group_status: WorkerGroupStatus + ): + """Executes failure handling decisions (ex: restart, terminate).""" + assert worker_group_status.errors + + for callback in self._controller_callbacks: + callback.before_controller_execute_failure_decision( + failure_decision, worker_group_status + ) + + if failure_decision == FailureDecision.NOOP: + assert self._state == TrainControllerState.RUNNING + return + + errors_str = "\n".join( + [ + f"[Rank {worker_rank}]\n{error}" + for worker_rank, error in worker_group_status.errors.items() + ] + ) + + if failure_decision == FailureDecision.RESTART: + logger.error( + "Restarting training worker group after encountering " + f"failures on {len(worker_group_status.errors)} worker(s):\n" + f"{errors_str}" + ) + # Shutdown the worker group so that we don't keep polling errored tasks. + self._worker_group.shutdown() + self._set_state(TrainControllerState.RECOVERING) + elif failure_decision == FailureDecision.RAISE: + logger.error( + "Terminating training worker group after encountering " + f"failure(s) on {len(worker_group_status.errors)} worker(s):\n" + f"{errors_str}" + ) + self._set_state(TrainControllerState.ERRORED) + self._training_failed_error = TrainingFailedError( + worker_failures=worker_group_status.errors + ) + else: + raise ValueError(f"Unexpected failure decision: {failure_decision}") + + def _poll_workers(self) -> WorkerGroupStatus: + # Ensure that the time between polls is at least HEALTH_CHECK_INTERVAL_S. + time_since_last_poll = time_monotonic() - self._latest_poll_time + if time_since_last_poll < self._health_check_interval_s: + remaining_time = max( + self._health_check_interval_s - time_since_last_poll, 0 + ) + time.sleep(remaining_time) + + status = self._worker_group.poll_status(timeout=self._health_check_interval_s) + self._latest_poll_time = time_monotonic() + return status + + def _restart_worker_group(self, num_workers: int, resources_per_worker: dict): + """Restart the worker group and launch the train function.""" + self._worker_group.shutdown() + + # If there's a latest checkpoint that's been committed, + # use it to restore the worker group. + latest_checkpoint_result = self._checkpoint_manager.latest_checkpoint_result + latest_checkpoint = ( + latest_checkpoint_result.checkpoint if latest_checkpoint_result else None + ) + placement_strategy = self._scaling_policy.scaling_config.placement_strategy + + # Start the worker group with the latest checkpoint if there is one. + # Otherwise, start the worker group with the checkpoint set by controller. + # Finally, if there is no checkpoint, start the worker group with None. + try: + self._worker_group.start( + train_fn=self._train_fn, + num_workers=num_workers, + resources_per_worker=resources_per_worker, + placement_strategy=placement_strategy, + checkpoint=latest_checkpoint or self._resume_from_checkpoint, + ) + except (WorkerGroupStartupTimeoutError, WorkerGroupStartupFailedError) as e: + logger.error( + "Retrying the launch of the training worker group. " + f"The previous launch attempt encountered the following failure:\n{e}" + ) + + # TODO: Should this logic go through the failure policy? + # The current logic will always try recovering unconditionally + # on startup errors without a retry limit. + self._set_state(TrainControllerState.RECOVERING) + return + + # TODO: Consider starting the worker group asynchronously. + self._set_state(TrainControllerState.RUNNING) + + def _start(self): + for callback in self._controller_callbacks: + callback.after_controller_start() + + def _shutdown(self): + self._worker_group.shutdown() + + for callback in self._controller_callbacks: + callback.before_controller_shutdown() + + def get_worker_group(self) -> WorkerGroup: + return self._worker_group + + def get_state(self) -> TrainControllerState: + return self._state + + def _set_state(self, state: TrainControllerState): + previous_state = self._state + self._state = state + + for callback in self._controller_callbacks: + callback.after_controller_state_update(previous_state, state) + + def _run_control_loop_iteration(self): + """Run a single iteration of the control loop. + + Steps: + 1. Poll the worker group for status. + 2. If the worker group is initializing or recovering from an error, + make a scaling decision and execute it. + 3. If the worker group has finished, set the controller state to FINISHED. + 4. If the worker group has errors, make a failure decision and execute it. + 5. Otherwise, the worker group is running healthily. + Query the scaling policy for a scaling decision and execute it. + """ + assert self.get_state() in ( + TrainControllerState.RUNNING, + TrainControllerState.RECOVERING, + TrainControllerState.INITIALIZING, + ), self.get_state() + + worker_group_status = self._poll_workers() + + if worker_group_status.finished and not worker_group_status.errors: + self._set_state(TrainControllerState.FINISHED) + return + + if self.get_state() in ( + TrainControllerState.INITIALIZING, + TrainControllerState.RECOVERING, + ): + scaling_decision = ( + self._scaling_policy.make_decision_for_non_running_worker_group( + worker_group_status + ) + ) + self._execute_scaling_decision(scaling_decision, worker_group_status) + elif self.get_state() == TrainControllerState.RUNNING: + if worker_group_status.errors: + failure_decision = self._failure_policy.make_decision( + worker_group_status + ) + self._execute_failure_decision(failure_decision, worker_group_status) + else: + scaling_decision = ( + self._scaling_policy.make_decision_for_running_worker_group( + worker_group_status + ) + ) + self._execute_scaling_decision(scaling_decision, worker_group_status) + + @wrap_auto_init + def run(self): + """Run the main control loop. Exits when training is finished or errored.""" + self._start() + + while self.get_state() not in ( + TrainControllerState.ERRORED, + TrainControllerState.FINISHED, + ): + self._run_control_loop_iteration() + + self._shutdown() + + def get_result(self) -> Result: + """Get the final training result from the TrainController.""" + + controller_state = self.get_state() + if controller_state not in ( + TrainControllerState.FINISHED, + TrainControllerState.ERRORED, + ): + raise ValueError( + f"Cannot get result when controller is in state {controller_state}" + ) + + latest_checkpoint_result = self._checkpoint_manager.latest_checkpoint_result + latest_metrics = ( + latest_checkpoint_result.metrics if latest_checkpoint_result else None + ) + latest_checkpoint = ( + latest_checkpoint_result.checkpoint if latest_checkpoint_result else None + ) + best_checkpoints = [ + (r.checkpoint, r.metrics) + for r in self._checkpoint_manager.best_checkpoint_results + ] + storage_filesystem, storage_fs_path = get_fs_and_path( + self._run_config.storage_path, self._run_config.storage_filesystem + ) + experiment_fs_path = Path(storage_fs_path, self._run_config.name).as_posix() + + return Result( + metrics=latest_metrics, + checkpoint=latest_checkpoint, + error=self._training_failed_error, + path=experiment_fs_path, + best_checkpoints=best_checkpoints, + _storage_filesystem=storage_filesystem, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c527f6b0e2ddc612abfefd11e24c5a5f3242263e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/default.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/default.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20ceeeabd00a1c074c365bfdd6756507f6be7bd2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/default.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/factory.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/factory.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..082d27034ba270d0702a7c1211be17a746bf941f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/factory.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/failure_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/failure_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe9c354ca421b55bb326c937f1c7b62635a17c63 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/failure_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/default.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/default.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb4457a957a551bb6364f8f9bd5749fec366645 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/default.py @@ -0,0 +1,44 @@ +import logging + +from ray.train import FailureConfig +from ray.train.v2._internal.execution.failure_handling import ( + FailureDecision, + FailurePolicy, +) +from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus + +logger = logging.getLogger(__name__) + + +class DefaultFailurePolicy(FailurePolicy): + def __init__(self, failure_config: FailureConfig): + super().__init__(failure_config) + self._total_failures = 0 + + def make_decision(self, worker_group_status: WorkerGroupStatus) -> FailureDecision: + if not worker_group_status.errors: + return FailureDecision.NOOP + + self._total_failures += 1 + + if self.failure_config.max_failures == -1: + logger.info( + "Deciding to RESTART, since infinite retry is enabled. " + f"Encountered {self._total_failures} failures so far." + ) + return FailureDecision.RESTART + + if self._total_failures > self.failure_config.max_failures: + logger.info( + "Deciding to TERMINATE, since the total failure count " + f"({self._total_failures}) exceeded the maximum allowed failures: " + f"FailureConfig(max_failures={self.failure_config.max_failures})." + ) + return FailureDecision.RAISE + + logger.info( + "Deciding to RESTART, since the total " + f"failure count ({self._total_failures}) <= " + f"FailureConfig(max_failures={self.failure_config.max_failures})." + ) + return FailureDecision.RESTART diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/factory.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..48a902c0300e7931b4c1808ba90766fdaf1bff98 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/factory.py @@ -0,0 +1,13 @@ +from ray.train import FailureConfig +from ray.train.v2._internal.execution.failure_handling import ( + DefaultFailurePolicy, + FailurePolicy, +) + + +def create_failure_policy(failure_config: FailureConfig) -> FailurePolicy: + """Create a failure policy from the given failure config. + + Defaults to the `DefaultFailurePolicy` implementation. + """ + return DefaultFailurePolicy(failure_config=failure_config) diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__init__.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f51407285701dbce7bc8ba3d4d93c6b0afb8bd5c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__init__.py @@ -0,0 +1,19 @@ +# isort: off +from .scaling_policy import ScalingDecision, ScalingPolicy, NoopDecision, ResizeDecision +from .fixed import FixedScalingPolicy +from .factory import create_scaling_policy + +# isort: on + + +__all__ = [ + "ScalingPolicy", + "FixedScalingPolicy", + "ScalingDecision", + "NoopDecision", + "ResizeDecision", + "create_scaling_policy", +] + + +# DO NOT ADD ANYTHING AFTER THIS LINE. diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/factory.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/factory.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c73c81c7248d2695309ae8fdd14dbc98d3fdddb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/factory.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/fixed.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/fixed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cf3090a3ad894c00abeffcef045b34ef74dce51 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/fixed.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/scaling_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/scaling_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ba39a116360f1900f466126cef0600de8225f45 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/scaling_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/factory.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..506a1df09e94e35f093eb9e5802f714ec7ed1543 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/factory.py @@ -0,0 +1,13 @@ +from ray.train.v2._internal.execution.scaling_policy import ( + FixedScalingPolicy, + ScalingPolicy, +) +from ray.train.v2.api.config import ScalingConfig + + +def create_scaling_policy(scaling_config: ScalingConfig) -> ScalingPolicy: + """Create a scaling policy from the given scaling config. + + Defaults to the `FixedScalingPolicy` implementation. + """ + return FixedScalingPolicy(scaling_config=scaling_config) diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/fixed.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/fixed.py new file mode 100644 index 0000000000000000000000000000000000000000..e3263250ad43741b92efe0384a4bf437d0784052 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/fixed.py @@ -0,0 +1,22 @@ +from ray.train.v2._internal.execution.scaling_policy import ( + NoopDecision, + ResizeDecision, + ScalingDecision, + ScalingPolicy, +) +from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus + + +class FixedScalingPolicy(ScalingPolicy): + def make_decision_for_non_running_worker_group( + self, worker_group_status: WorkerGroupStatus + ) -> ScalingDecision: + return ResizeDecision( + num_workers=self.scaling_config.num_workers, + resources_per_worker=self.scaling_config._resources_per_worker_not_none, + ) + + def make_decision_for_running_worker_group( + self, worker_group_status: WorkerGroupStatus + ) -> ScalingDecision: + return NoopDecision() diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/scaling_policy.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/scaling_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2774a514ae127190632d991d3b7f6b29417db7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/scaling_policy.py @@ -0,0 +1,51 @@ +import abc +from dataclasses import dataclass +from typing import Dict + +from ray.train.v2._internal.execution.callback import ControllerCallback +from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus +from ray.train.v2.api.config import ScalingConfig + + +@dataclass +class ScalingDecision: + pass + + +@dataclass +class NoopDecision(ScalingDecision): + pass + + +@dataclass +class ResizeDecision(ScalingDecision): + num_workers: int + resources_per_worker: Dict[str, float] + + +class ScalingPolicy(abc.ABC, ControllerCallback): + """A policy that determines when and how to scale a worker group. + + This can be used to implement elasticity and fault tolerance. + + Recovery decisions are made when workers are in an inactive or unhealthy state. + Upscale decisions are optional and are made when workers are healthy. + """ + + def __init__(self, scaling_config: ScalingConfig): + self.scaling_config = scaling_config + + @abc.abstractmethod + def make_decision_for_non_running_worker_group( + self, worker_group_status: WorkerGroupStatus + ) -> ScalingDecision: + """Makes a scaling decision when the worker group is initializing + or recovering from an error.""" + raise NotImplementedError + + @abc.abstractmethod + def make_decision_for_running_worker_group( + self, worker_group_status: WorkerGroupStatus + ) -> ScalingDecision: + """Makes a scaling decision when monitoring healthy, running workers.""" + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/storage.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/storage.py new file mode 100644 index 0000000000000000000000000000000000000000..ef3c8b4e293b996cf36a0b4eb1f90bd6b5176581 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/storage.py @@ -0,0 +1,551 @@ +# Try import ray[train] core requirements (defined in setup.py) +# isort: off +try: + import fsspec # noqa + from fsspec.implementations.local import LocalFileSystem + +except (ImportError, ModuleNotFoundError) as e: + raise RuntimeError( + "fsspec is a required dependency of Ray Train and Ray Tune. " + "Please install with: `pip install fsspec`" + ) from e + +try: + import pyarrow + import pyarrow.fs + +except (ImportError, ModuleNotFoundError) as e: + raise RuntimeError( + "pyarrow is a required dependency of Ray Train and Ray Tune. " + "Please install with: `pip install pyarrow`" + ) from e +# isort: on + +import fnmatch +import logging +import os +import shutil +from pathlib import Path +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Type, Union + +from ray.air._internal.filelock import TempFileLock +from ray.train.constants import _get_ray_train_session_dir +from ray.train.v2._internal.constants import ( + CHECKPOINT_MANAGER_SNAPSHOT_FILENAME, + VALIDATE_STORAGE_MARKER_FILENAME, +) +from ray.train.v2._internal.util import date_str +from ray.util.annotations import DeveloperAPI + +if TYPE_CHECKING: + from ray.train import Checkpoint + + +logger = logging.getLogger(__name__) + + +class _ExcludingLocalFilesystem(LocalFileSystem): + """LocalFileSystem wrapper to exclude files according to patterns. + + Args: + root_path: Root path to strip when matching with the exclude pattern. + Ex: root_path="/tmp/a/b/c", exclude=["*a*"], will exclude + /tmp/a/b/c/_a_.txt but not ALL of /tmp/a/*. + exclude: List of patterns that are applied to files returned by + ``self.find()``. If a file path matches this pattern, it will + be excluded. + + """ + + def __init__(self, root_path: Path, exclude: List[str], **kwargs): + super().__init__(**kwargs) + self._exclude = exclude + self._root_path = root_path + + @property + def fsid(self): + return "_excluding_local" + + def _should_exclude(self, path: str) -> bool: + """Return True if `path` (relative to `root_path`) matches any of the + `self._exclude` patterns.""" + path = Path(path) + relative_path = path.relative_to(self._root_path).as_posix() + match_candidates = [relative_path] + if path.is_dir(): + # Everything is in posix path format ('/') + match_candidates.append(relative_path + "/") + + for excl in self._exclude: + if any(fnmatch.fnmatch(candidate, excl) for candidate in match_candidates): + return True + return False + + def find(self, path, maxdepth=None, withdirs=False, detail=False, **kwargs): + """Call parent find() and exclude from result.""" + paths = super().find( + path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, **kwargs + ) + if detail: + return { + path: out + for path, out in paths.items() + if not self._should_exclude(path) + } + else: + return [path for path in paths if not self._should_exclude(path)] + + +def _pyarrow_fs_copy_files( + source, destination, source_filesystem=None, destination_filesystem=None, **kwargs +): + if isinstance(destination_filesystem, pyarrow.fs.S3FileSystem): + # Workaround multi-threading issue with pyarrow. Note that use_threads=True + # is safe for download, just not for uploads, see: + # https://github.com/apache/arrow/issues/32372 + kwargs.setdefault("use_threads", False) + + # Use a large chunk size to speed up large checkpoint transfers. + kwargs.setdefault("chunk_size", 64 * 1024 * 1024) + + return pyarrow.fs.copy_files( + source, + destination, + source_filesystem=source_filesystem, + destination_filesystem=destination_filesystem, + **kwargs, + ) + + +# TODO(justinvyu): Add unit tests for all these utils. + + +def _delete_fs_path(fs: pyarrow.fs.FileSystem, fs_path: str): + is_dir = _is_directory(fs, fs_path) + + try: + if is_dir: + fs.delete_dir(fs_path) + else: + fs.delete_file(fs_path) + except Exception: + logger.exception(f"Caught exception when deleting path at ({fs}, {fs_path}):") + + +def _download_from_fs_path( + fs: pyarrow.fs.FileSystem, + fs_path: str, + local_path: str, + filelock: bool = True, +): + """Downloads a directory or file from (fs, fs_path) to a local path. + + If fs_path points to a directory: + - The full directory contents are downloaded directly into `local_path`, + rather than to a subdirectory of `local_path`. + + If fs_path points to a file: + - The file is downloaded to `local_path`, which is expected to be a file path. + + If the download fails, the `local_path` contents are + cleaned up before raising, if the directory did not previously exist. + + NOTE: This method creates `local_path`'s parent directories if they do not + already exist. If the download fails, this does NOT clean up all the parent + directories that were created. + + Args: + fs: The filesystem to download from. + fs_path: The filesystem path (either a directory or a file) to download. + local_path: The local path to download to. + filelock: Whether to require a file lock before downloading, useful for + multiple downloads to the same directory that may be happening in parallel. + + Raises: + FileNotFoundError: if (fs, fs_path) doesn't exist. + """ + + _local_path = Path(local_path).resolve() + exists_before = _local_path.exists() + if _is_directory(fs=fs, fs_path=fs_path): + _local_path.mkdir(parents=True, exist_ok=True) + else: + _local_path.parent.mkdir(parents=True, exist_ok=True) + + try: + if filelock: + with TempFileLock(f"{os.path.normpath(local_path)}.lock"): + _pyarrow_fs_copy_files(fs_path, local_path, source_filesystem=fs) + else: + _pyarrow_fs_copy_files(fs_path, local_path, source_filesystem=fs) + except Exception as e: + # Clean up the directory if downloading was unsuccessful + if not exists_before: + shutil.rmtree(local_path, ignore_errors=True) + raise e + + +def _upload_to_fs_path( + local_path: str, + fs: pyarrow.fs.FileSystem, + fs_path: str, + exclude: Optional[List[str]] = None, +) -> None: + """Uploads a local directory or file to (fs, fs_path). + + NOTE: This will create all necessary parent directories at the destination. + + Args: + local_path: The local path to upload. + fs: The filesystem to upload to. + fs_path: The filesystem path where the dir/file will be uploaded to. + exclude: A list of filename matches to exclude from upload. This includes + all files under subdirectories as well. + This pattern will match with the relative paths of all files under + `local_path`. + Ex: ["*.png"] to exclude all .png images. + """ + + if not exclude: + # TODO(justinvyu): uploading a single file doesn't work + # (since we always create a directory at fs_path) + _create_directory(fs=fs, fs_path=fs_path) + _pyarrow_fs_copy_files(local_path, fs_path, destination_filesystem=fs) + return + + _upload_to_uri_with_exclude_fsspec( + local_path=local_path, fs=fs, fs_path=fs_path, exclude=exclude + ) + + +def _upload_to_uri_with_exclude_fsspec( + local_path: str, fs: "pyarrow.fs", fs_path: str, exclude: Optional[List[str]] +) -> None: + local_fs = _ExcludingLocalFilesystem(root_path=local_path, exclude=exclude) + handler = pyarrow.fs.FSSpecHandler(local_fs) + source_fs = pyarrow.fs.PyFileSystem(handler) + + _create_directory(fs=fs, fs_path=fs_path) + _pyarrow_fs_copy_files( + local_path, fs_path, source_filesystem=source_fs, destination_filesystem=fs + ) + + +def _list_at_fs_path( + fs: pyarrow.fs.FileSystem, + fs_path: str, + file_filter: Callable[[pyarrow.fs.FileInfo], bool] = lambda x: True, +) -> List[str]: + """Returns the list of filenames at (fs, fs_path), similar to os.listdir. + + If the path doesn't exist, returns an empty list. + """ + selector = pyarrow.fs.FileSelector(fs_path, allow_not_found=True, recursive=False) + return [ + os.path.relpath(file_info.path.lstrip("/"), start=fs_path.lstrip("/")) + for file_info in fs.get_file_info(selector) + if file_filter(file_info) + ] + + +def _exists_at_fs_path(fs: pyarrow.fs.FileSystem, fs_path: str) -> bool: + """Returns True if (fs, fs_path) exists.""" + + valid = fs.get_file_info(fs_path) + return valid.type != pyarrow.fs.FileType.NotFound + + +def _is_directory(fs: pyarrow.fs.FileSystem, fs_path: str) -> bool: + """Checks if (fs, fs_path) is a directory or a file. + + Raises: + FileNotFoundError: if (fs, fs_path) doesn't exist. + """ + + file_info = fs.get_file_info(fs_path) + if file_info.type == pyarrow.fs.FileType.NotFound: + raise FileNotFoundError(f"Path not found: ({fs}, {fs_path})") + + return not file_info.is_file + + +def _create_directory(fs: pyarrow.fs.FileSystem, fs_path: str) -> None: + """Create directory at (fs, fs_path). + + Some external filesystems require directories to already exist, or at least + the `netloc` to be created (e.g. PyArrows ``mock://`` filesystem). + + Generally this should be done before and outside of Ray applications. This + utility is thus primarily used in testing, e.g. of ``mock://` URIs. + """ + try: + fs.create_dir(fs_path) + except Exception: + logger.exception( + f"Caught exception when creating directory at ({fs}, {fs_path}):" + ) + + +def get_fs_and_path( + storage_path: Union[str, os.PathLike], + storage_filesystem: Optional[pyarrow.fs.FileSystem] = None, +) -> Tuple[pyarrow.fs.FileSystem, str]: + """Returns the fs and path from a storage path and an optional custom fs. + + Args: + storage_path: A storage path or URI. (ex: s3://bucket/path or /tmp/ray_results) + storage_filesystem: A custom filesystem to use. If not provided, + this will be auto-resolved by pyarrow. If provided, the storage_path + is assumed to be prefix-stripped already, and must be a valid path + on the filesystem. + """ + storage_path = str(storage_path) + + if storage_filesystem: + return storage_filesystem, storage_path + + return pyarrow.fs.FileSystem.from_uri(storage_path) + + +@DeveloperAPI +class StorageContext: + """Shared context that holds the source of truth for all paths and + storage utilities, passed along from the driver to workers. + + This object defines a few types of paths: + 1. *_fs_path: A path on the `storage_filesystem`. This is a regular path + which has been prefix-stripped by pyarrow.fs.FileSystem.from_uri and + can be joined with `Path(...).as_posix()`. + 2. *_driver_staging_path: The temporary staging directory on the local filesystem + where driver artifacts are saved to before persisting them to storage. + 3. trial_working_directory: The local filesystem path that the remote + actors' working directories are moved to by default. + This is separated from the driver staging path so that driver syncing + does not implicitly upload the trial working directory, for trials on the + driver node. + + Example with storage_path="mock:///bucket/path?param=1": + + >>> import ray + >>> from ray.train._internal.storage import StorageContext + >>> import os + >>> _ = ray.init() + >>> storage = StorageContext( + ... storage_path="mock://netloc/bucket/path?param=1", + ... experiment_dir_name="exp_name", + ... ) + >>> storage.storage_filesystem # Auto-resolved # doctest: +ELLIPSIS + >> storage.experiment_fs_path + 'bucket/path/exp_name' + >>> storage.experiment_driver_staging_path # doctest: +ELLIPSIS + '/tmp/ray/session_.../artifacts/.../exp_name/driver_artifacts' + >>> storage.trial_dir_name = "trial_dir" + >>> storage.trial_fs_path + 'bucket/path/exp_name/trial_dir' + >>> storage.trial_driver_staging_path # doctest: +ELLIPSIS + '/tmp/ray/session_.../artifacts/.../exp_name/driver_artifacts/trial_dir' + >>> storage.trial_working_directory # doctest: +ELLIPSIS + '/tmp/ray/session_.../artifacts/.../exp_name/working_dirs/trial_dir' + >>> ray.shutdown() + + Example with storage_path="/tmp/ray_results": + + >>> from ray.train._internal.storage import StorageContext + >>> storage = StorageContext( + ... storage_path="/tmp/ray_results", + ... experiment_dir_name="exp_name", + ... ) + >>> storage.storage_fs_path + '/tmp/ray_results' + >>> storage.experiment_fs_path + '/tmp/ray_results/exp_name' + >>> storage.storage_filesystem # Auto-resolved # doctest: +ELLIPSIS + " + ) + + def _create_validation_file(self): + """On the creation of a storage context, create a validation file at the + storage path to verify that the storage path can be written to. + This validation file is also used to check whether the storage path is + accessible by all nodes in the cluster.""" + valid_file = Path( + self.experiment_fs_path, VALIDATE_STORAGE_MARKER_FILENAME + ).as_posix() + self.storage_filesystem.create_dir(self.experiment_fs_path) + with self.storage_filesystem.open_output_stream(valid_file): + pass + + def _check_validation_file(self): + """Checks that the validation file exists at the storage path.""" + valid_file = Path( + self.experiment_fs_path, VALIDATE_STORAGE_MARKER_FILENAME + ).as_posix() + if not _exists_at_fs_path(fs=self.storage_filesystem, fs_path=valid_file): + raise RuntimeError( + f"Unable to set up cluster storage with the following settings:\n{self}" + "\nCheck that all nodes in the cluster have read/write access " + "to the configured storage path. `RunConfig(storage_path)` should be " + "set to a cloud storage URI or a shared filesystem path accessible " + "by all nodes in your cluster ('s3://bucket' or '/mnt/nfs'). " + "A local path on the head node is not accessible by worker nodes. " + "See: https://docs.ray.io/en/latest/train/user-guides/persistent-storage.html" # noqa: E501 + ) + + def persist_current_checkpoint( + self, checkpoint: "Checkpoint", checkpoint_dir_name: str + ) -> "Checkpoint": + """Persists a given checkpoint to the current checkpoint path on the filesystem. + + This method copies the checkpoint files to the storage location. + It's up to the user to delete the original checkpoint files if desired. + + For example, the original directory is typically a local temp directory. + + Args: + checkpoint: The checkpoint to persist to + (fs, experiment_fs_path / checkpoint_dir_name). + + Returns: + Checkpoint: A Checkpoint pointing to the persisted checkpoint location. + """ + # TODO(justinvyu): Fix this cyclical import. + from ray.train import Checkpoint + + checkpoint_fs_path = self.build_checkpoint_path_from_name(checkpoint_dir_name) + + logger.debug( + "Copying checkpoint files to storage path:\n" + "({source_fs}, {source}) -> ({dest_fs}, {destination})".format( + source=checkpoint.path, + destination=checkpoint_fs_path, + source_fs=checkpoint.filesystem, + dest_fs=self.storage_filesystem, + ) + ) + + # Raise an error if the storage path is not accessible when + # attempting to upload a checkpoint from a remote worker. + # Ex: If storage_path is a local path, then a validation marker + # will only exist on the head node but not the worker nodes. + self._check_validation_file() + + self.storage_filesystem.create_dir(checkpoint_fs_path) + _pyarrow_fs_copy_files( + source=checkpoint.path, + destination=checkpoint_fs_path, + source_filesystem=checkpoint.filesystem, + destination_filesystem=self.storage_filesystem, + ) + + persisted_checkpoint = Checkpoint( + filesystem=self.storage_filesystem, + path=checkpoint_fs_path, + ) + logger.info(f"Checkpoint successfully created at: {persisted_checkpoint}") + return persisted_checkpoint + + @property + def experiment_fs_path(self) -> str: + """The path on the `storage_filesystem` to the experiment directory. + + NOTE: This does not have a URI prefix anymore, since it has been stripped + by pyarrow.fs.FileSystem.from_uri already. The URI scheme information is + kept in `storage_filesystem` instead. + """ + return Path(self.storage_fs_path, self.experiment_dir_name).as_posix() + + @property + def local_working_directory(self) -> str: + """Every ray train worker will set this directory as its working directory.""" + if self.experiment_dir_name is None: + raise RuntimeError( + "Cannot access `local_working_directory` without " + "setting `experiment_dir_name`" + ) + return Path(_get_ray_train_session_dir(), self.experiment_dir_name).as_posix() + + @property + def checkpoint_manager_snapshot_path(self) -> str: + """The path to the checkpoint manager snapshot file.""" + return Path( + self.experiment_fs_path, CHECKPOINT_MANAGER_SNAPSHOT_FILENAME + ).as_posix() + + @staticmethod + def get_experiment_dir_name(run_obj: Union[str, Callable, Type]) -> str: + from ray.tune.experiment import Experiment + + run_identifier = Experiment.get_trainable_name(run_obj) + + if bool(int(os.environ.get("TUNE_DISABLE_DATED_SUBDIR", 0))): + dir_name = run_identifier + else: + dir_name = "{}_{}".format(run_identifier, date_str()) + return dir_name + + @staticmethod + def make_default_checkpoint_dir_name(): + """Get the name of the checkpoint directory by timestamp.""" + return f"checkpoint_{date_str(include_ms=True)}" + + def extract_checkpoint_dir_name_from_path(self, checkpoint_path: str) -> str: + """Get the checkpoint name from the checkpoint path. + The parent directory of the checkpoint path should be the experiment directory. + """ + # TODO: Use Pathlib to extract the name when supports at least Python 3.9 + experiment_fs_path = self.experiment_fs_path + "/" + if not checkpoint_path.startswith(experiment_fs_path): + raise ValueError( + f"Checkpoint path {checkpoint_path} is not under the experiment " + f"directory {self.experiment_fs_path}." + ) + return checkpoint_path[len(experiment_fs_path) :] + + def build_checkpoint_path_from_name(self, checkpoint_name: str) -> str: + """Get the checkpoint path from the checkpoint name. + The parent directory of the checkpoint path should be the experiment directory. + """ + return Path(self.experiment_fs_path, checkpoint_name).as_posix() diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/worker_group/__pycache__/worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/worker_group/__pycache__/worker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f14c25c0f777f44f63f6e038b8fbd04cb819dfe0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/worker_group/__pycache__/worker.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/worker_group/thread_runner.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/worker_group/thread_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..fce7f4d550238fb27f151c5ad8ce0c3124065723 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/worker_group/thread_runner.py @@ -0,0 +1,73 @@ +import threading +import traceback +from typing import Callable, Optional, TypeVar + +from ray.train.v2._internal.exceptions import UserExceptionWithTraceback + +T = TypeVar("T") + + +class ThreadRunner: + """Utility to run a user function as a thread and capture its return value + or exception. + """ + + def __init__(self): + self._ret: Optional[T] = None + self._exc: Optional[UserExceptionWithTraceback] = None + + self._thread: Optional[threading.Thread] = None + self._lock = threading.Lock() + + self._is_running = False + + def run(self, target: Callable[[], T]) -> None: + if self._thread is not None: + raise RuntimeError("Thread is already running.") + + def _run_target(): + with self._lock: + self._is_running = True + + try: + result = target() + with self._lock: + self._ret = result + except BaseException as e: + with self._lock: + # Exclude the the first 2 frames from the traceback, which are + # the `ThreadRunner._run_target` and `construct_train_func` calls. + # TODO(justinvyu): This is brittle and may break if the call stack + # changes. Figure out a more robust way to exclude these frames. + exc_traceback_str = traceback.format_exc( + limit=-(len(traceback.extract_tb(e.__traceback__)) - 2) + ) + self._exc = UserExceptionWithTraceback( + e, traceback_str=exc_traceback_str + ) + + with self._lock: + self._is_running = False + + self._thread = threading.Thread(target=_run_target, daemon=True) + self._thread.start() + + def is_running(self) -> bool: + with self._lock: + return self._is_running + + def get_error(self) -> Optional[BaseException]: + with self._lock: + return self._exc + + def get_return_value(self) -> Optional[T]: + with self._lock: + return self._ret + + def join(self, timeout: Optional[float] = None) -> T: + if self._thread is None: + raise RuntimeError("Must call `run` before trying to `join`.") + + self._thread.join(timeout=timeout) + + return self.get_return_value() diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/util.py b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/util.py new file mode 100644 index 0000000000000000000000000000000000000000..23c48717da18cd7861b768615ea747186a51a8bb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/_internal/util.py @@ -0,0 +1,171 @@ +import contextlib +import functools +import time +from datetime import datetime +from typing import ( + Any, + Callable, + ContextManager, + Dict, + Generator, + List, + Optional, + TypeVar, + Union, +) + +import ray +from ray.train._internal.utils import count_required_parameters +from ray.types import ObjectRef + +T = TypeVar("T") + + +def bundle_to_remote_args(bundle: dict) -> dict: + """Convert a bundle of resources to Ray actor/task arguments. + + >>> bundle_to_remote_args({"GPU": 1, "memory": 1, "custom": 0.1}) + {'num_cpus': 0, 'num_gpus': 1, 'memory': 1, 'resources': {'custom': 0.1}} + """ + bundle = bundle.copy() + args = { + "num_cpus": bundle.pop("CPU", 0), + "num_gpus": bundle.pop("GPU", 0), + "memory": bundle.pop("memory", 0), + } + if bundle: + args["resources"] = bundle + return args + + +def construct_train_func( + train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]], + config: Optional[Dict[str, Any]], + train_func_context: ContextManager, + fn_arg_name: Optional[str] = "train_func", +) -> Callable[[], T]: + """Validates and constructs the training function to execute. + Args: + train_func: The training function to execute. + This can either take in no arguments or a ``config`` dict. + config (Optional[Dict]): Configurations to pass into + ``train_func``. If None then an empty Dict will be created. + train_func_context: Context manager for user's `train_func`, which executes + backend-specific logic before and after the training function. + fn_arg_name (Optional[str]): The name of training function to use for error + messages. + Returns: + A valid training function. + Raises: + ValueError: if the input ``train_func`` is invalid. + """ + num_required_params = count_required_parameters(train_func) + + if num_required_params > 1: + err_msg = ( + f"{fn_arg_name} should take in 0 or 1 required arguments, but it accepts " + f"{num_required_params} required arguments instead." + ) + raise ValueError(err_msg) + + if num_required_params == 1: + config = config or {} + + @functools.wraps(train_func) + def train_fn(): + with train_func_context(): + return train_func(config) + + else: # num_params == 0 + + @functools.wraps(train_func) + def train_fn(): + with train_func_context(): + return train_func() + + return train_fn + + +def date_str(include_ms: bool = False): + pattern = "%Y-%m-%d_%H-%M-%S" + if include_ms: + pattern += ".%f" + return datetime.today().strftime(pattern) + + +def time_monotonic(): + return time.monotonic() + + +def _copy_doc(copy_func): + def wrapped(func): + func.__doc__ = copy_func.__doc__ + return func + + return wrapped + + +def ray_get_safe( + object_refs: Union[ObjectRef, List[ObjectRef]] +) -> Union[Any, List[Any]]: + """This is a safe version of `ray.get` that raises an exception immediately + if an input task dies, while the others are still running. + + TODO(ml-team, core-team): This is NOT a long-term solution, + and we should not maintain this function indefinitely. + This is a mitigation for a Ray Core bug, and should be removed when + that is fixed. + See here: https://github.com/ray-project/ray/issues/47204 + + Args: + object_refs: A single or list of object refs to wait on. + + Returns: + task_outputs: The outputs of the tasks. + + Raises: + `RayTaskError`/`RayActorError`: if any of the tasks encounter a runtime error + or fail due to actor/task death (ex: node failure). + """ + is_list = isinstance(object_refs, list) + object_refs = object_refs if is_list else [object_refs] + + unready = object_refs + task_to_output = {} + while unready: + ready, unready = ray.wait(unready, num_returns=1) + if ready: + for task, task_output in zip(ready, ray.get(ready)): + task_to_output[task] = task_output + + assert len(task_to_output) == len(object_refs) + ordered_outputs = [task_to_output[task] for task in object_refs] + return ordered_outputs if is_list else ordered_outputs[0] + + +@contextlib.contextmanager +def invoke_context_managers( + context_managers: List[ContextManager], +) -> Generator[None, None, None]: + """ + Utility to invoke a list of context managers and yield sequentially. + + Args: + context_managers: List of context managers to invoke. + """ + with contextlib.ExitStack() as stack: + for context_manager in context_managers: + stack.enter_context(context_manager()) + yield + + +def get_module_name(obj: object) -> str: + """Returns the full module name of the given object, including its qualified name. + + Args: + obj: The object (class, function, etc.) whose module name is required. + + Returns: + Full module and qualified name as a string. + """ + return f"{obj.__module__}.{obj.__qualname__}" diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/horovod/__init__.py b/.venv/lib/python3.11/site-packages/ray/train/v2/horovod/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/horovod/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/horovod/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab3c7bfa680660ba96d456f91bc3f313371460d8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/horovod/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/horovod/__pycache__/horovod_trainer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/horovod/__pycache__/horovod_trainer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66f317200e9c0dac7bae48e2fbea06688a77bc11 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/horovod/__pycache__/horovod_trainer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/horovod/horovod_trainer.py b/.venv/lib/python3.11/site-packages/ray/train/v2/horovod/horovod_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..10fffd2bd846bcdaea85e62fa881ccd43599b7d0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/horovod/horovod_trainer.py @@ -0,0 +1,37 @@ +from typing import Any, Callable, Dict, Optional, Union + +from ray.air.config import RunConfig, ScalingConfig +from ray.train import Checkpoint, DataConfig +from ray.train.data_parallel_trainer import DataParallelTrainer +from ray.train.horovod.config import HorovodConfig +from ray.train.trainer import GenDataset +from ray.util.annotations import Deprecated + + +@Deprecated +class HorovodTrainer(DataParallelTrainer): + """A Trainer for data parallel Horovod training. + + Horovod Trainer is Deprecated. + """ + + def __init__( + self, + train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]], + *, + train_loop_config: Optional[Dict] = None, + horovod_config: Optional[HorovodConfig] = None, + scaling_config: Optional[ScalingConfig] = None, + dataset_config: Optional[DataConfig] = None, + run_config: Optional[RunConfig] = None, + datasets: Optional[Dict[str, GenDataset]] = None, + metadata: Optional[Dict[str, Any]] = None, + resume_from_checkpoint: Optional[Checkpoint] = None, + ): + raise DeprecationWarning( + "`HorovodTrainer` is not supported and is scheduled to be removed " + "in the future. " + "Please consider using `TorchTrainer` or `TensorflowTrainer`, " + "fall back to the old implementation with `RAY_TRAIN_V2_ENABLED=0`, " + "or file an issue on Github describing your use case." + ) diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/lightgbm/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/lightgbm/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..156aad7e83ee0465307e4eb46d168dc14b745607 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/lightgbm/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/lightgbm/__pycache__/lightgbm_trainer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/lightgbm/__pycache__/lightgbm_trainer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4138af8337a333c770189f763bdcb6b34cbcf7b3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/lightgbm/__pycache__/lightgbm_trainer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/lightning/__init__.py b/.venv/lib/python3.11/site-packages/ray/train/v2/lightning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/lightning/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/lightning/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e51ada2955dc858749ad6b630d12b0f804c498d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/lightning/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/lightning/__pycache__/lightning_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/lightning/__pycache__/lightning_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2aca6df9cdf7917f739d817f97f39496135130d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/lightning/__pycache__/lightning_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/lightning/lightning_utils.py b/.venv/lib/python3.11/site-packages/ray/train/v2/lightning/lightning_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..008107ba7a0baf053e1a144e2a034a635ed04055 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/lightning/lightning_utils.py @@ -0,0 +1,58 @@ +import os +import shutil +import tempfile +from pathlib import Path + +import ray.train +from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag +from ray.train.lightning._lightning_utils import ( + RayTrainReportCallback as RayTrainReportCallbackV1, +) +from ray.train.lightning._lightning_utils import import_lightning +from ray.util import PublicAPI + +pl = import_lightning() + + +@PublicAPI(stability="beta") +class RayTrainReportCallback(RayTrainReportCallbackV1): + """A simple callback that reports checkpoints to Ray on train epoch end. + + This callback is a subclass of `lightning.pytorch.callbacks.Callback + `_. + + It fetches the latest `trainer.callback_metrics` and reports together with + the checkpoint on each training epoch end. + + Checkpoints will be saved in the following structure: + + checkpoint_{timestamp}/ Ray Train's checkpoint folder + └─ checkpoint.ckpt Lightning's checkpoint format + + For customized reporting and checkpointing logic, implement your own + `lightning.pytorch.callbacks.Callback` following this user + guide: :ref:`Saving and Loading Checkpoints `. + """ + + def __init__(self) -> None: + # TODO: Upstream this change into ray.train.lightning. + # The difference in this version is removing the trial directory usage. + job_id = ray.get_runtime_context().get_job_id() + experiment_name = ray.train.get_context().get_experiment_name() + self.local_rank = ray.train.get_context().get_local_rank() + + # Create a root temporary directory for storing local checkpoints + # before persisting to storage. + # Lightning's checkpointing implementation requires that this directory + # is a common path across all workers. + # Construct the path prefix with the job id and experiment name, + # which are shared across workers for a Ray Train run. + # This path should not be shared across different Ray Train runs. + self.tmpdir_prefix = Path( + tempfile.gettempdir(), + f"lightning_checkpoints-job_id={job_id}-name={experiment_name}", + ).as_posix() + if os.path.isdir(self.tmpdir_prefix) and self.local_rank == 0: + shutil.rmtree(self.tmpdir_prefix) + + record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYTRAINREPORTCALLBACK, "1") diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/tensorflow/__init__.py b/.venv/lib/python3.11/site-packages/ray/train/v2/tensorflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/tensorflow/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/tensorflow/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49863c95eb9af340e35243d20bd2946fc71fc06d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/tensorflow/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/tensorflow/__pycache__/tensorflow_trainer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/tensorflow/__pycache__/tensorflow_trainer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e23cfc0f5d668bc650eb3cd7118aee5b653d2890 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/tensorflow/__pycache__/tensorflow_trainer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/tensorflow/tensorflow_trainer.py b/.venv/lib/python3.11/site-packages/ray/train/v2/tensorflow/tensorflow_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..61da7d38f92bb4b48c413e57c9d0963e7c35e9b4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/tensorflow/tensorflow_trainer.py @@ -0,0 +1,190 @@ +from typing import Any, Callable, Dict, Optional, Union + +from ray.train import Checkpoint, DataConfig +from ray.train.tensorflow.config import TensorflowConfig +from ray.train.trainer import GenDataset +from ray.train.v2._internal.constants import _UNSUPPORTED +from ray.train.v2.api.config import RunConfig, ScalingConfig +from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer +from ray.util import PublicAPI + + +@PublicAPI(stability="beta") +class TensorflowTrainer(DataParallelTrainer): + """A Trainer for data parallel Tensorflow training. + + At a high level, this Trainer does the following: + + 1. Launches multiple workers as defined by the ``scaling_config``. + 2. Sets up a distributed Tensorflow environment + on these workers as defined by the ``tensorflow_config``. + 3. Ingests the input ``datasets`` based on the ``dataset_config``. + 4. Runs the input ``train_loop_per_worker(train_loop_config)`` + on all workers. + + For more details, see: + + * :ref:`Tensorflow Guide ` + + Inside the ``train_loop_per_worker`` function, you can use any of the + :ref:`Ray Train loop methods `. + + .. warning:: + Ray will not automatically set any environment variables or configuration + related to local parallelism / threading + :ref:`aside from "OMP_NUM_THREADS" `. + If you desire greater control over TensorFlow threading, use + the ``tf.config.threading`` module (eg. + ``tf.config.threading.set_inter_op_parallelism_threads(num_cpus)``) + at the beginning of your ``train_loop_per_worker`` function. + + + .. testcode:: + + from ray import train + + def train_loop_per_worker(): + # Report intermediate results for callbacks or logging and + # checkpoint data. + train.report(...) + + # Returns dict of last saved checkpoint. + train.get_checkpoint() + + # Returns the Dataset shard for the given key. + train.get_dataset_shard("my_dataset") + + # Returns the total number of workers executing training. + train.get_context().get_world_size() + + # Returns the rank of this worker. + train.get_context().get_world_rank() + + # Returns the rank of the worker on the current node. + train.get_context().get_local_rank() + + Any returns from the ``train_loop_per_worker`` will be discarded and not + used or persisted anywhere. + + Example: + + .. testcode:: + + import os + import tempfile + import tensorflow as tf + + import ray + from ray import train + from ray.train import Checkpoint, ScalingConfig + from ray.train.tensorflow import TensorflowTrainer + + def build_model(): + # toy neural network : 1-layer + return tf.keras.Sequential( + [tf.keras.layers.Dense( + 1, activation="linear", input_shape=(1,))] + ) + + def train_loop_per_worker(config): + dataset_shard = train.get_dataset_shard("train") + strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() + with strategy.scope(): + model = build_model() + model.compile( + optimizer="Adam", loss="mean_squared_error", metrics=["mse"]) + + tf_dataset = dataset_shard.to_tf( + feature_columns="x", + label_columns="y", + batch_size=1 + ) + for epoch in range(config["num_epochs"]): + model.fit(tf_dataset) + + # Create checkpoint. + checkpoint_dir = tempfile.mkdtemp() + model.save_weights( + os.path.join(checkpoint_dir, "my_checkpoint") + ) + checkpoint = Checkpoint.from_directory(checkpoint_dir) + + train.report( + {}, + checkpoint=checkpoint, + ) + + train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)]) + trainer = TensorflowTrainer( + train_loop_per_worker=train_loop_per_worker, + scaling_config=ScalingConfig(num_workers=3, use_gpu=True), + datasets={"train": train_dataset}, + train_loop_config={"num_epochs": 2}, + ) + result = trainer.fit() + + .. testoutput:: + :options:+ELLIPSIS + :hide: + + ... + + Args: + train_loop_per_worker: The training function to execute on each worker. + This function can either take in zero arguments or a single ``Dict`` + argument which is set by defining ``train_loop_config``. + Within this function you can use any of the + :ref:`Ray Train Loop utilities `. + train_loop_config: A configuration ``Dict`` to pass in as an argument to + ``train_loop_per_worker``. + This is typically used for specifying hyperparameters. Passing large + datasets via `train_loop_config` is not recommended and may introduce + large overhead and unknown issues with serialization and deserialization. + tensorflow_config: The configuration for setting up the Tensorflow + Distributed backend. If set to None, a default configuration will be + used in which GPU training uses NCCL and CPU training uses Gloo. + scaling_config: The configuration for how to scale data parallel training. + ``num_workers`` determines how many Python processes are used for training, + and ``use_gpu`` determines whether or not each process should use GPUs. + See :class:`~ray.train.ScalingConfig` for more info. + run_config: The configuration for the execution of the training run. + See :class:`~ray.train.RunConfig` for more info. + datasets: The Ray Datasets to ingest for training. + Datasets are keyed by name (``{name: dataset}``). + Each dataset can be accessed from within the ``train_loop_per_worker`` + by calling ``ray.train.get_dataset_shard(name)``. + Sharding and additional configuration can be done by + passing in a ``dataset_config``. + resume_from_checkpoint: A checkpoint to resume training from. + metadata: Dict that should be made available via + `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()` + for checkpoints saved from this Trainer. Must be JSON-serializable. + """ + + def __init__( + self, + train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]], + *, + train_loop_config: Optional[Dict] = None, + tensorflow_config: Optional[TensorflowConfig] = None, + scaling_config: Optional[ScalingConfig] = None, + dataset_config: Optional[DataConfig] = None, + run_config: Optional[RunConfig] = None, + datasets: Optional[Dict[str, GenDataset]] = None, + metadata: Optional[Dict[str, Any]] = _UNSUPPORTED, + resume_from_checkpoint: Optional[Checkpoint] = None, + ): + if not tensorflow_config: + tensorflow_config = TensorflowConfig() + + super(TensorflowTrainer, self).__init__( + train_loop_per_worker=train_loop_per_worker, + train_loop_config=train_loop_config, + backend_config=tensorflow_config, + scaling_config=scaling_config, + dataset_config=dataset_config, + run_config=run_config, + datasets=datasets, + resume_from_checkpoint=resume_from_checkpoint, + metadata=metadata, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/torch/__init__.py b/.venv/lib/python3.11/site-packages/ray/train/v2/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/torch/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/torch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f18107612d5ba367aecd3428d985e32bd27d4c3b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/torch/__pycache__/torch_trainer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/torch/__pycache__/torch_trainer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..543de747529652da3ef7ca7f4b196f4e62fd8063 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/torch/__pycache__/torch_trainer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/torch/__pycache__/train_loop_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/torch/__pycache__/train_loop_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57b74f117ff6fdb515a1046f23e823af332932e7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/torch/__pycache__/train_loop_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/torch/torch_trainer.py b/.venv/lib/python3.11/site-packages/ray/train/v2/torch/torch_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..171e10564aeecbb532b620ba25603ca19c6a9e8d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/torch/torch_trainer.py @@ -0,0 +1,207 @@ +from typing import Any, Callable, Dict, Optional, Union + +from ray.train import Checkpoint, DataConfig +from ray.train.torch import TorchConfig +from ray.train.trainer import GenDataset +from ray.train.v2.api.config import RunConfig, ScalingConfig +from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer +from ray.util import PublicAPI + + +@PublicAPI(stability="stable") +class TorchTrainer(DataParallelTrainer): + """A Trainer for data parallel PyTorch training. + + At a high level, this Trainer does the following: + + 1. Launches multiple workers as defined by the ``scaling_config``. + 2. Sets up a distributed PyTorch environment + on these workers as defined by the ``torch_config``. + 3. Ingests the input ``datasets`` based on the ``dataset_config``. + 4. Runs the input ``train_loop_per_worker(train_loop_config)`` + on all workers. + + For more details, see: + + * :ref:`PyTorch Guide ` + * :ref:`PyTorch Lightning Guide ` + * :ref:`Hugging Face Transformers Guide ` + + Example: + + .. testcode:: + + import os + import tempfile + + import torch + from torch import nn + from torch.nn.parallel import DistributedDataParallel + + import ray + from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig + from ray.train.torch import TorchTrainer + + # If using GPUs, set this to True. + use_gpu = False + # Number of processes to run training on. + num_workers = 4 + + # Define your network structure. + class NeuralNetwork(nn.Module): + def __init__(self): + super(NeuralNetwork, self).__init__() + self.layer1 = nn.Linear(1, 32) + self.relu = nn.ReLU() + self.layer2 = nn.Linear(32, 1) + + def forward(self, input): + return self.layer2(self.relu(self.layer1(input))) + + # Training loop. + def train_loop_per_worker(config): + + # Read configurations. + lr = config["lr"] + batch_size = config["batch_size"] + num_epochs = config["num_epochs"] + + # Fetch training dataset. + train_dataset_shard = ray.train.get_dataset_shard("train") + + # Instantiate and prepare model for training. + model = NeuralNetwork() + model = ray.train.torch.prepare_model(model) + + # Define loss and optimizer. + loss_fn = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=lr) + + # Create data loader. + dataloader = train_dataset_shard.iter_torch_batches( + batch_size=batch_size, dtypes=torch.float + ) + + # Train multiple epochs. + for epoch in range(num_epochs): + + # Train epoch. + for batch in dataloader: + output = model(batch["input"]) + loss = loss_fn(output, batch["label"]) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Create checkpoint. + base_model = (model.module + if isinstance(model, DistributedDataParallel) else model) + checkpoint_dir = tempfile.mkdtemp() + torch.save( + {"model_state_dict": base_model.state_dict()}, + os.path.join(checkpoint_dir, "model.pt"), + ) + checkpoint = Checkpoint.from_directory(checkpoint_dir) + + # Report metrics and checkpoint. + ray.train.report({"loss": loss.item()}, checkpoint=checkpoint) + + + # Define configurations. + train_loop_config = {"num_epochs": 20, "lr": 0.01, "batch_size": 32} + scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu) + run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1)) + + # Define datasets. + train_dataset = ray.data.from_items( + [{"input": [x], "label": [2 * x + 1]} for x in range(2000)] + ) + datasets = {"train": train_dataset} + + # Initialize the Trainer. + trainer = TorchTrainer( + train_loop_per_worker=train_loop_per_worker, + train_loop_config=train_loop_config, + scaling_config=scaling_config, + run_config=run_config, + datasets=datasets + ) + + # Train the model. + result = trainer.fit() + + # Inspect the results. + final_loss = result.metrics["loss"] + + .. testoutput:: + :hide: + + ... + + Args: + + train_loop_per_worker: The training function to execute on each worker. + This function can either take in zero arguments or a single ``Dict`` + argument which is set by defining ``train_loop_config``. + Within this function you can use any of the + :ref:`Ray Train Loop utilities `. + train_loop_config: A configuration ``Dict`` to pass in as an argument to + ``train_loop_per_worker``. + This is typically used for specifying hyperparameters. Passing large + datasets via `train_loop_config` is not recommended and may introduce + large overhead and unknown issues with serialization and deserialization. + torch_config: The configuration for setting up the PyTorch Distributed backend. + If set to None, a default configuration will be used in which + GPU training uses NCCL and CPU training uses Gloo. + scaling_config: The configuration for how to scale data parallel training. + ``num_workers`` determines how many Python processes are used for training, + and ``use_gpu`` determines whether or not each process should use GPUs. + See :class:`~ray.train.ScalingConfig` for more info. + run_config: The configuration for the execution of the training run. + See :class:`~ray.train.RunConfig` for more info. + datasets: The Ray Datasets to ingest for training. + Datasets are keyed by name (``{name: dataset}``). + Each dataset can be accessed from within the ``train_loop_per_worker`` + by calling ``ray.train.get_dataset_shard(name)``. + Sharding and additional configuration can be done by + passing in a ``dataset_config``. + dataset_config: The configuration for ingesting the input ``datasets``. + By default, all the Ray Dataset are split equally across workers. + See :class:`~ray.train.DataConfig` for more details. + resume_from_checkpoint: A checkpoint to resume training from. + This checkpoint can be accessed from within ``train_loop_per_worker`` + by calling ``ray.train.get_checkpoint()``. + metadata: Dict that should be made available via + `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()` + for checkpoints saved from this Trainer. Must be JSON-serializable. + """ + + def __init__( + self, + train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]], + *, + train_loop_config: Optional[Dict] = None, + torch_config: Optional[TorchConfig] = None, + scaling_config: Optional[ScalingConfig] = None, + run_config: Optional[RunConfig] = None, + datasets: Optional[Dict[str, GenDataset]] = None, + dataset_config: Optional[DataConfig] = None, + metadata: Optional[Dict[str, Any]] = None, + resume_from_checkpoint: Optional[Checkpoint] = None, + ): + torch_config = torch_config or TorchConfig() + if not torch_config.backend: + torch_config.backend = "nccl" if scaling_config.use_gpu else "gloo" + + super(TorchTrainer, self).__init__( + train_loop_per_worker=train_loop_per_worker, + train_loop_config=train_loop_config, + backend_config=torch_config, + scaling_config=scaling_config, + run_config=run_config, + dataset_config=dataset_config, + datasets=datasets, + # TODO: Re-enable below. + # resume_from_checkpoint=resume_from_checkpoint, + # metadata=metadata, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/torch/train_loop_utils.py b/.venv/lib/python3.11/site-packages/ray/train/v2/torch/train_loop_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e641feb21e240a455e314cd0323993ad43b38fcb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/train/v2/torch/train_loop_utils.py @@ -0,0 +1,293 @@ +import logging +import os +import random +from typing import Any, Callable, Dict, Optional, Union + +import numpy as np +import torch +from packaging.version import Version +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import ( + DataLoader, + DistributedSampler, + IterableDataset, + RandomSampler, + SequentialSampler, +) + +import ray.train.torch +from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag +from ray.train.torch.train_loop_utils import _WrappedDataLoader +from ray.util.annotations import PublicAPI + +if Version(torch.__version__) < Version("1.11.0"): + FullyShardedDataParallel = None +else: + from torch.distributed.fsdp import FullyShardedDataParallel + +logger = logging.getLogger(__name__) + + +def prepare_model( + model: torch.nn.Module, + move_to_device: Union[bool, torch.device] = True, + parallel_strategy: Optional[str] = "ddp", + parallel_strategy_kwargs: Optional[Dict[str, Any]] = None, +) -> torch.nn.Module: + """Prepares the model for distributed execution. + + This allows you to use the same exact code regardless of number of + workers or the device type being used (CPU, GPU). + + Args: + model (torch.nn.Module): A torch model to prepare. + move_to_device: Either a boolean indiciating whether to move + the model to the correct device or an actual device to + move the model to. If set to False, the model needs + to manually be moved to the correct device. + parallel_strategy ("ddp", "fsdp", or None): Whether to wrap models + in ``DistributedDataParallel``, ``FullyShardedDataParallel``, + or neither. + parallel_strategy_kwargs (Dict[str, Any]): Args to pass into + ``DistributedDataParallel`` or ``FullyShardedDataParallel`` + initialization if ``parallel_strategy`` is set to "ddp" + or "fsdp", respectively. + """ + if parallel_strategy == "fsdp" and FullyShardedDataParallel is None: + raise ImportError( + "FullyShardedDataParallel requires torch>=1.11.0. " + "Run `pip install 'torch>=1.11.0'` to use FullyShardedDataParallel." + ) + + record_extra_usage_tag(TagKey.TRAIN_TORCH_PREPARE_MODEL, "1") + + parallel_strategy_kwargs = parallel_strategy_kwargs or {} + + rank = ray.train.get_context().get_local_rank() + + if isinstance(move_to_device, torch.device): + device = move_to_device + else: + device = ray.train.torch.get_device() + if isinstance(device, list): + device = device[0] + + if torch.cuda.is_available(): + torch.cuda.set_device(device) + + if move_to_device: + if rank == 0: + logger.info(f"Moving model to device: {device}") + else: + logger.debug(f"Moving model to device: {device}") + model = model.to(device) + + world_size = ray.train.get_context().get_world_size() + + if parallel_strategy and world_size > 1: + if parallel_strategy == "ddp": + DataParallel = DistributedDataParallel + if torch.cuda.is_available(): + parallel_strategy_kwargs = { + "device_ids": [device], + "output_device": device, + **parallel_strategy_kwargs, + } + else: + if not torch.cuda.is_available(): + raise RuntimeError( + "FSDP is only available with GPU-enabled " + "training. Set " + "`use_gpu=True` in your Trainer to train with " + "GPUs." + ) + DataParallel = FullyShardedDataParallel + if rank == 0: + logger.info(f"Wrapping provided model in {DataParallel.__name__}.") + else: + logger.debug(f"Wrapping provided model in {DataParallel.__name__}.") + model = DataParallel(model, **parallel_strategy_kwargs) + + return model + + +@PublicAPI(stability="stable") +def prepare_data_loader( + data_loader: torch.utils.data.DataLoader, + add_dist_sampler: bool = True, + move_to_device: bool = True, + auto_transfer: bool = True, +) -> torch.utils.data.DataLoader: + """Prepares :class:`~torch.utils.data.DataLoader` for distributed execution. + + This allows you to use the same exact code regardless of number of + workers or the device type being used (CPU, GPU). + + .. note:: + + This method adds a `DistributedSampler` to the `DataLoader` if the + number of training workers is greater than 1. If shuffling is + enabled on the original `DataLoader`, then `shuffle=True` will also + be passed into the `DistributedSampler` constructor. `shuffle=False` + on the original `DataLoader` also means that shuffling is disabled + on the sampler. + + With more than 1 worker, calling the `DistributedSampler.set_epoch` method + at the beginning of each epoch before creating the DataLoader iterator + is necessary to make shuffling work properly across multiple epochs. + Otherwise, the same ordering will be always used. + See: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler # noqa: E501 + + Example: + + .. testcode: + :skipif: True + + import torch + + import ray.train.torch + + train_dataloader = torch.utils.data.DataLoader( + ..., batch_size=..., shuffle=True + ) + train_dataloader = ray.train.torch.prepare_data_loader(train_loader) + + for epoch in range(10): + if ray.train.get_context().get_world_size() > 1: + # Required for the distributed sampler to shuffle properly across epochs + train_dataloader.sampler.set_epoch(epoch) + + for X, y in train_loader: + # No need to move data to GPU, this is done by `prepare_data_loader`! + # X, y = X.to("cuda"), y.to("cuda") + ... + + Args: + data_loader (torch.utils.data.DataLoader): The DataLoader to + prepare. + add_dist_sampler: Whether to add a DistributedSampler to + the provided DataLoader. + move_to_device: If set, automatically move the data + returned by the data loader to the correct device. + auto_transfer: If set and device is GPU, another CUDA stream + is created to automatically copy data from host (CPU) memory + to device (GPU) memory (the default CUDA stream still runs the + training procedure). If device is CPU, it will be disabled + regardless of the setting. This configuration will be ignored + if ``move_to_device`` is False. + """ + record_extra_usage_tag(TagKey.TRAIN_TORCH_PREPARE_DATALOADER, "1") + + world_size = ray.train.get_context().get_world_size() + world_rank = ray.train.get_context().get_world_rank() + + # Only add Distributed Sampler if the following conditions hold: + # 1. More than one training worker is being used. + # 2. A DistributedSampler has not already been added by the user. + # 3. The dataset is not an IterableDataset. Samplers do not worker with + # IterableDatasets. + if ( + world_size > 1 + and not isinstance(data_loader.sampler, DistributedSampler) + and not ( + hasattr(data_loader, "dataset") + and isinstance(data_loader.dataset, IterableDataset) + ) + and add_dist_sampler + ): + + def with_sampler(loader): + # Automatically set the DistributedSampler + + # If you're using a sampler, the DataLoader shuffle flag must be set to + # False. Shuffling is instead determined by the shuffle argument passed + # to the DistributedSampler constructor. + + # If no sampler is passed to the DataLoader constructor, Torch + # constructs a default sampler. The default sampler is a RandomSampler + # if shuffling is enabled and a SequentialSampler otherwise. DataLoader + # does not have a shuffle attribute, so we instead identify whether + # shuffling is enabled by checking the default sampler type. + shuffle = not isinstance(loader.sampler, SequentialSampler) + worker_init_fn: Optional[Callable[[int], None]] = loader.worker_init_fn + generator: Optional[torch.Generator] = loader.generator + + using_default_sampler = isinstance( + loader.sampler, (SequentialSampler, RandomSampler) + ) + if not using_default_sampler and world_rank == 0: + logger.warning( + f"The {loader.sampler.__class__.__name__} will be overwritten " + "with a DistributedSampler. You can disable this by setting " + "`with_sampler` to False in `prepare_data_loader`." + ) + + data_loader_args = { + "dataset": loader.dataset, + "batch_size": loader.batch_size, + "shuffle": False, + "num_workers": loader.num_workers, + "collate_fn": loader.collate_fn, + "pin_memory": loader.pin_memory, + "drop_last": loader.drop_last, + "timeout": loader.timeout, + "worker_init_fn": worker_init_fn, + "generator": generator, + "sampler": DistributedSampler(loader.dataset, shuffle=shuffle), + } + return DataLoader(**data_loader_args) + + data_loader = with_sampler(data_loader) + + if move_to_device: + device = ray.train.torch.get_device() + data_loader = _WrappedDataLoader(data_loader, device, auto_transfer) + + return data_loader + + +@PublicAPI(stability="beta") +def accelerate(amp: bool = False) -> None: + raise NotImplementedError + + +@PublicAPI(stability="beta") +def prepare_optimizer(optimizer: torch.optim.Optimizer) -> torch.optim.Optimizer: + raise NotImplementedError + + +@PublicAPI(stability="beta") +def backward(tensor: torch.Tensor) -> None: + raise NotImplementedError + + +@PublicAPI(stability="stable") +def enable_reproducibility(seed: int = 0) -> None: + """Limits sources of nondeterministic behavior. + + This function: + + * Seeds PyTorch, Python, and NumPy. + * Disables CUDA convolution benchmarking. + * Configures PyTorch to use determinstic algorithms. + * Seeds workers spawned for multi-process data loading. + + Args: + seed: The number to seed libraries and data workers with. + + .. warning:: ``train.torch.enable_reproducibility()`` can't guarantee + completely reproducible results across executions. To learn more, read + the `PyTorch notes on randomness + `_. + """ + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.benchmark = False + + # If you want to use deterministic algorithms with CUDA, then you need to set + # the CUBLAS_WORKSPACE_CONFIG environment variable; otherwise, Torch errors. + # See https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility. + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/xgboost/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/xgboost/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5c7727a8f5385dd1a947d402a6edd6f5ab1f0d2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/xgboost/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/train/v2/xgboost/__pycache__/xgboost_trainer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/train/v2/xgboost/__pycache__/xgboost_trainer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b29f55e5626556c40fae544bf8779a8787ace6f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/train/v2/xgboost/__pycache__/xgboost_trainer.cpython-311.pyc differ