from __future__ import absolute_import, division, print_function

import os
import warnings
from random import random
from time import sleep
from uuid import uuid4

import pytest

from .. import Parallel, delayed, parallel_backend, parallel_config
from .._dask import DaskDistributedBackend
from ..parallel import AutoBatchingMixin, ThreadingBackend
from .common import np, with_numpy
from .test_parallel import (
    _recursive_backend_info,
    _test_deadlock_with_generator,
    _test_parallel_unordered_generator_returns_fastest_first,  # noqa: E501
)

distributed = pytest.importorskip("distributed")
dask = pytest.importorskip("dask")

# These imports need to come after the pytest.importorskip calls, hence the
# noqa: E402 markers.
from distributed import Client, LocalCluster, get_client  # noqa: E402
from distributed.metrics import time  # noqa: E402

# Note: pytest requires manually importing all fixtures used in the tests,
# as well as their dependencies.
from distributed.utils_test import cleanup, cluster, inc  # noqa: E402, F401


@pytest.fixture(autouse=True)
def avoid_dask_env_leaks(tmp_path):
    # When starting a dask nanny, the environment variables might change.
    # This fixture makes sure the environment is reset after the test.
    from joblib._parallel_backends import ParallelBackendBase

    old_value = {
        k: os.environ.get(k) for k in ParallelBackendBase.MAX_NUM_THREADS_VARS
    }
    yield
    # Reset the environment variables to their original values.
    for k, v in old_value.items():
        if v is None:
            os.environ.pop(k, None)
        else:
            os.environ[k] = v


def noop(*args, **kwargs):
    pass


def slow_raise_value_error(condition, duration=0.05):
    sleep(duration)
    if condition:
        raise ValueError("condition evaluated to True")


def count_events(event_name, client):
    # Count, for each worker of the cluster, the number of entries in its
    # log that match the given event name.
    worker_events = client.run(lambda dask_worker: dask_worker.log)
    event_counts = {}
    for w, events in worker_events.items():
        event_counts[w] = len(
            [event for event in list(events) if event[1] == event_name]
        )
    return event_counts


def test_simple(loop):
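    # Check that Parallel works with the dask backend, that task errors are
    # propagated, and that the same backend can be reused after a failure.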
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]

                with pytest.raises(ValueError):
                    Parallel()(
                        delayed(slow_raise_value_error)(i == 3) for i in range(10)
                    )

                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]


def test_dask_backend_uses_autobatching(loop):
    assert (
        DaskDistributedBackend.compute_batch_size
        is AutoBatchingMixin.compute_batch_size
    )

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                with Parallel() as parallel:
                    # The backend should be initialized with a default
                    # batch size of 1:
                    backend = parallel._backend
                    assert isinstance(backend, DaskDistributedBackend)
                    assert backend.parallel is parallel
                    assert backend._effective_batch_size == 1

                    # Launch many short tasks that should trigger
                    # auto-batching:
                    parallel(delayed(lambda: None)() for _ in range(int(1e4)))
                    assert backend._effective_batch_size > 10


def test_parallel_unordered_generator_returns_fastest_first_with_dask(
    n_jobs, context
):
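    # Run the shared helper from test_parallel on top of a small dask
    # cluster.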
    with distributed.Client(n_workers=2, threads_per_worker=2), context("dask"):
        _test_parallel_unordered_generator_returns_fastest_first(None, n_jobs)


def test_deadlock_with_generator_and_dask(context, return_as, n_jobs):
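    # Reuse the generator deadlock scenario from test_parallel, here running
    # on the dask backend.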
    with distributed.Client(n_workers=2, threads_per_worker=2), context("dask"):
        _test_deadlock_with_generator(None, return_as, n_jobs)


def test_nested_parallelism_with_dask(context):
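    # Nested Parallel calls dispatched to a dask cluster should all run on
    # the DaskDistributedBackend, both with and without a large argument
    # that triggers implicit scattering.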
    with distributed.Client(n_workers=2, threads_per_worker=2):
        # 10 MB of data as argument to trigger implicit scattering
        data = np.ones(int(1e7), dtype=np.uint8)
        for i in range(2):
            with context("dask"):
                backend_types_and_levels = _recursive_backend_info(data=data)
            assert len(backend_types_and_levels) == 4
            assert all(
                name == "DaskDistributedBackend"
                for name, _ in backend_types_and_levels
            )

        # No argument
        with context("dask"):
            backend_types_and_levels = _recursive_backend_info()
        assert len(backend_types_and_levels) == 4
        assert all(
            name == "DaskDistributedBackend"
            for name, _ in backend_types_and_levels
        )


def random2():
    return random()


def test_dont_assume_function_purity(loop):
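    # The dask backend should not treat tasks as pure: two calls to the same
    # zero-argument function must be executed separately and may return
    # different results.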
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                x, y = Parallel()(delayed(random2)() for i in range(2))
                assert x != y


def test_dask_funcname(loop, mixed):
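    # Check that the Batch wrapper exposes a descriptive name, derived from
    # the wrapped function(s), that shows up in the dask task names.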
    from joblib._dask import Batch

    if not mixed:
        tasks = [delayed(inc)(i) for i in range(4)]
        batch_repr = "batch_of_inc_4_calls"
    else:
        tasks = [delayed(abs)(i) if i % 2 else delayed(inc)(i) for i in range(4)]
        batch_repr = "mixed_batch_of_inc_4_calls"

    assert repr(Batch(tasks)) == batch_repr

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:
            with parallel_config(backend="dask"):
                _ = Parallel(batch_size=2, pre_dispatch="all")(tasks)

            def f(dask_scheduler):
                return list(dask_scheduler.transition_log)

            batch_repr = batch_repr.replace("4", "2")
            log = client.run_on_scheduler(f)
            assert all("batch_of_inc" in tup[0] for tup in log)


def test_no_undesired_distributed_cache_hit():
    # Dask has a pickle cache for callables that are called many times.
    # Because the dask backend used to wrap both the functions and the
    # arguments under instances of the Batch callable class, this caching
    # mechanism could lead to bugs as described in:
    # https://github.com/joblib/joblib/pull/1055
    # The joblib-dask backend has been refactored to avoid bundling the
    # arguments as an attribute of the Batch instance to avoid this problem.
    # This test serves as a non-regression check.

    # Use a large number of input arguments to give the AutoBatchingMixin
    # enough tasks for auto-batching to kick in.
    lists = [[] for _ in range(100)]
    np = pytest.importorskip("numpy")
    X = np.arange(int(1e6))

    def isolated_operation(list_, data=None):
        if data is not None:
            np.testing.assert_array_equal(data, X)
        list_.append(uuid4().hex)
        return list_

    cluster = LocalCluster(n_workers=1, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_config(backend="dask"):
            # dispatches joblib.parallel.BatchedCalls
            res = Parallel()(delayed(isolated_operation)(list_) for list_ in lists)

        # The original arguments should not have been mutated, as the
        # mutation happens in the dask worker process.
        assert lists == [[] for _ in range(100)]

        # Here we did not pass any large numpy array as argument to
        # isolated_operation, so no scattering event should happen under the
        # hood.
        counts = count_events("receive-from-scatter", client)
        assert sum(counts.values()) == 0
        assert all([len(r) == 1 for r in res])

        with parallel_config(backend="dask"):
            # Append a large array which will be scattered by dask, and
            # dispatch joblib._dask.Batch
            res = Parallel()(
                delayed(isolated_operation)(list_, data=X) for list_ in lists
            )

        # This time, auto-scattering should have kicked in.
        counts = count_events("receive-from-scatter", client)
        assert sum(counts.values()) > 0
        assert all([len(r) == 1 for r in res])
    finally:
        client.close(timeout=30)
        cluster.close(timeout=30)


class CountSerialized(object):
    def __init__(self, x):
        self.x = x
        self.count = 0

    def __add__(self, other):
        return self.x + getattr(other, "x", other)

    __radd__ = __add__

    def __reduce__(self):
        # Count every pickling round-trip of this object.
        self.count += 1
        return (CountSerialized, (self.x,))


def add5(a, b, c, d=0, e=0):
    return a + b + c + d + e


def test_manual_scatter(loop):
    # Let's check that the number of times scattered and non-scattered
    # variables are serialized is consistent between `joblib.Parallel` calls
    # and an equivalent native `client.submit` call.
    # The number of serializations can vary from one dask version to
    # another, so this test only checks that `joblib.Parallel` does not add
    # more serialization steps than a native `client.submit` call; it does
    # not check for an exact number of serialization steps.
    w, x, y, z = (CountSerialized(i) for i in range(4))
    f = delayed(add5)
    tasks = [f(x, y, z, d=4, e=5) for _ in range(10)]
    tasks += [
        f(x, z, y, d=5, e=4),
        f(y, x, z, d=x, e=5),
        f(z, z, x, d=z, e=y),
    ]
    expected = [func(*args, **kwargs) for func, args, kwargs in tasks]

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask", scatter=[w, x, y]):
                results_parallel = Parallel(batch_size=1)(tasks)
                assert results_parallel == expected

            # Check that an error is raised for bad arguments, as scatter
            # must take a list/tuple.
            with pytest.raises(TypeError):
                with parallel_config(backend="dask", loop=loop, scatter=1):
                    pass

    # Scattered variables are only serialized during the scatter operation.
    # Checking with an extra variable as this count can vary from one dask
    # version to another.
    n_serialization_scatter_with_parallel = w.count
    assert x.count == n_serialization_scatter_with_parallel
    assert y.count == n_serialization_scatter_with_parallel
    n_serialization_with_parallel = z.count

    # Reset the cluster and the serialization counts.
    for var in (w, x, y, z):
        var.count = 0

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            scattered = dict()
            for obj in w, x, y:
                scattered[id(obj)] = client.scatter(obj, broadcast=True)
            results_native = [
                client.submit(
                    func,
                    *(scattered.get(id(arg), arg) for arg in args),
                    **dict(
                        (key, scattered.get(id(value), value))
                        for (key, value) in kwargs.items()
                    ),
                    key=str(uuid4()),
                ).result()
                for (func, args, kwargs) in tasks
            ]
            assert results_native == expected

    # Now check that the number of serialization steps is the same for
    # joblib and native dask calls.
    n_serialization_scatter_native = w.count
    assert x.count == n_serialization_scatter_native
    assert y.count == n_serialization_scatter_native
    assert n_serialization_scatter_with_parallel == n_serialization_scatter_native

    distributed_version = tuple(int(v) for v in distributed.__version__.split("."))
    if distributed_version < (2023, 4):
        # Prior to 2023.4, the serialization was adding an extra call to
        # __reduce__ for the last job `f(z, z, x, d=z, e=y)`, because `z`
        # appears both in the args and kwargs, which is not the case when
        # running with joblib. Cope with this discrepancy.
        assert z.count == n_serialization_with_parallel + 1
    else:
        assert z.count == n_serialization_with_parallel


# When the same IOLoop is used for multiple clients in a row, use
# loop_in_thread instead of loop to prevent the Client from closing it. See
# dask/distributed #4112.
def test_auto_scatter(loop_in_thread):
    np = pytest.importorskip("numpy")
    data1 = np.ones(int(1e4), dtype=np.uint8)
    data2 = np.ones(int(1e4), dtype=np.uint8)
    data_to_process = ([data1] * 3) + ([data2] * 3)

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                # Passing the same data as arg and kwarg triggers a single
                # scatter operation whose result is reused.
                Parallel()(
                    delayed(noop)(data, data, i, opt=data)
                    for i, data in enumerate(data_to_process)
                )
            # By default, large arrays are automatically scattered with
            # broadcast=1, which means that one worker must directly receive
            # the data from the scatter operation once.
            counts = count_events("receive-from-scatter", client)
            assert counts[a["address"]] + counts[b["address"]] == 2

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                Parallel()(delayed(noop)(data1[:3], i) for i in range(5))
            # Small arrays are passed within the task definition without
            # going through a scatter operation.
            counts = count_events("receive-from-scatter", client)
            assert counts[a["address"]] == 0
            assert counts[b["address"]] == 0


def test_nested_scatter(loop, retry_no):
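    # Check that nested Parallel calls running inside dask tasks can slice
    # and forward a large, implicitly scattered array without errors.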
    np = pytest.importorskip("numpy")
    NUM_INNER_TASKS = 10
    NUM_OUTER_TASKS = 10

    def my_sum(x, i, j):
        return np.sum(x)

    def outer_function_joblib(array, i):
        client = get_client()  # noqa
        with parallel_config(backend="dask"):
            results = Parallel()(
                delayed(my_sum)(array[j:], i, j) for j in range(NUM_INNER_TASKS)
            )
        return sum(results)

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as _:
            with parallel_config(backend="dask"):
                my_array = np.ones(10000)
                _ = Parallel()(
                    delayed(outer_function_joblib)(my_array[i:], i)
                    for i in range(NUM_OUTER_TASKS)
                )


def test_nested_backend_context_manager(loop_in_thread):
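    # Nested Parallel calls executed inside dask tasks should complete
    # without deadlocking and should stay confined to the dask workers.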
    def get_nested_pids():
        pids = set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        pids |= set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        return pids

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)() for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2

        # No deadlocks
        with Client(s["address"], loop=loop_in_thread) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)() for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2


def test_nested_backend_context_manager_implicit_n_jobs(loop):
    # Check that Parallel with no explicit n_jobs value automatically
    # selects all the dask workers, including in nested calls.
    def _backend_type(p):
        return p._backend.__class__.__name__

    def get_nested_implicit_n_jobs():
        with Parallel() as p:
            return _backend_type(p), p.n_jobs

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                with Parallel() as p:
                    assert _backend_type(p) == "DaskDistributedBackend"
                    assert p.n_jobs == -1
                    all_nested_n_jobs = p(
                        delayed(get_nested_implicit_n_jobs)() for _ in range(2)
                    )
                for backend_type, nested_n_jobs in all_nested_n_jobs:
                    assert backend_type == "DaskDistributedBackend"
                    assert nested_n_jobs == -1


def test_errors(loop):
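    # Using the dask backend without a running distributed client should
    # raise an informative ValueError.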
    with pytest.raises(ValueError) as info:
        with parallel_config(backend="dask"):
            pass
    assert "create a dask client" in str(info.value).lower()


def test_correct_nested_backend(loop):
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            # No requirement: the nested backend should be the dask backend.
            with parallel_config(backend="dask"):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require=None) for _ in range(1)
                )
                assert isinstance(result[0][0][0], DaskDistributedBackend)

            # Requiring shared memory: the nested backend should be
            # threading.
            with parallel_config(backend="dask"):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require="sharedmem") for _ in range(1)
                )
                assert isinstance(result[0][0][0], ThreadingBackend)


def outer(nested_require):
    return Parallel(n_jobs=2, prefer="threads")(
        delayed(middle)(nested_require) for _ in range(1)
    )


def middle(require):
    return Parallel(n_jobs=2, require=require)(delayed(inner)() for _ in range(1))


def inner():
    return Parallel()._backend


def test_secede_with_no_processes(loop):
    # See https://github.com/dask/distributed/issues/1775
    with Client(loop=loop, processes=False, set_as_default=True):
        with parallel_config(backend="dask"):
            Parallel(n_jobs=4)(delayed(id)(i) for i in range(2))


def _worker_address(_):
    from distributed import get_worker

    return get_worker().address


def test_dask_backend_keywords(loop):
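    # The `workers` keyword passed through parallel_config should pin all
    # tasks to the requested worker address.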
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask", workers=a["address"]):
                seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
                assert seq == [a["address"]] * 10

            with parallel_config(backend="dask", workers=b["address"]):
                seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
                assert seq == [b["address"]] * 10


def test_scheduler_tasks_cleanup(loop):
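    # After a Parallel call completes, no task should linger on the dask
    # scheduler and no future should remain on the client.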
    with Client(processes=False, loop=loop) as client:
        with parallel_config(backend="dask"):
            Parallel()(delayed(inc)(i) for i in range(10))

        start = time()
        while client.cluster.scheduler.tasks:
            sleep(0.01)
            assert time() < start + 5

        assert not client.futures


def test_wait_for_workers(cluster_strategy):
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    if cluster_strategy == "adaptive":
        cluster.adapt(minimum=0, maximum=2)
    elif cluster_strategy == "late_scaling":
        # Tell the cluster to start workers, but this is a non-blocking call
        # and new workers might take time to connect. In this case the
        # Parallel call should wait for at least one worker to come up
        # before starting to schedule work.
        cluster.scale(2)
    try:
        with parallel_config(backend="dask"):
            # The following should wait a bit for at least one worker to
            # become available.
            Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        client.close()
        cluster.close()


def test_wait_for_workers_timeout():
    # Start a cluster with 0 worker:
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_config(backend="dask", wait_for_workers_timeout=0.1):
            # Short timeout: the dask backend raises an informative
            # TimeoutError.
            msg = "DaskDistributedBackend has no worker after 0.1 seconds."
            with pytest.raises(TimeoutError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))

        with parallel_config(backend="dask", wait_for_workers_timeout=0):
            # No timeout: fall back to the generic joblib failure.
            msg = "DaskDistributedBackend has no active worker"
            with pytest.raises(RuntimeError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        client.close()
        cluster.close()


def test_joblib_warning_inside_dask_daemonic_worker(backend):
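    # Inside a daemonic dask worker, spawning child worker processes is not
    # allowed: joblib should warn (mentioning the distributed.worker.daemon
    # setting) instead of failing.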
    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)
    try:

        def func_using_joblib_parallel():
            # Somehow, trying to check the warning type here (e.g. with
            # pytest.warns(UserWarning)) makes the test hang. Work-around:
            # return the warning record to the client so that the warning
            # check can be done client-side.
            with warnings.catch_warnings(record=True) as record:
                Parallel(n_jobs=2, backend=backend)(
                    delayed(inc)(i) for i in range(10)
                )
            return record

        fut = client.submit(func_using_joblib_parallel)
        record = fut.result()

        assert len(record) == 1
        warning = record[0].message
        assert isinstance(warning, UserWarning)
        assert "distributed.worker.daemon" in str(warning)
    finally:
        client.close(timeout=30)
        cluster.close(timeout=30)