| | """ |
| | Thin wrappers around `concurrent.futures`. |
| | """ |
| | from contextlib import contextmanager |
| | from operator import length_hint |
| | from os import cpu_count |
| |
|
| | from ..auto import tqdm as tqdm_auto |
| | from ..std import TqdmWarning |
| |
|
| | __author__ = {"github.com/": ["casperdcl"]} |
| | __all__ = ['thread_map', 'process_map'] |
| |
|
| |
|
| | @contextmanager |
| | def ensure_lock(tqdm_class, lock_name=""): |
| | """get (create if necessary) and then restore `tqdm_class`'s lock""" |
| | old_lock = getattr(tqdm_class, '_lock', None) |
| | lock = old_lock or tqdm_class.get_lock() |
| | lock = getattr(lock, lock_name, lock) |
| | tqdm_class.set_lock(lock) |
| | yield lock |
| | if old_lock is None: |
| | del tqdm_class._lock |
| | else: |
| | tqdm_class.set_lock(old_lock) |
| |
|
| |
|
| | def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs): |
| | """ |
| | Implementation of `thread_map` and `process_map`. |
| | |
| | Parameters |
| | ---------- |
| | tqdm_class : [default: tqdm.auto.tqdm]. |
| | max_workers : [default: min(32, cpu_count() + 4)]. |
| | chunksize : [default: 1]. |
| | lock_name : [default: "":str]. |
| | """ |
| | kwargs = tqdm_kwargs.copy() |
| | if "total" not in kwargs: |
| | kwargs["total"] = length_hint(iterables[0]) |
| | tqdm_class = kwargs.pop("tqdm_class", tqdm_auto) |
| | max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4)) |
| | chunksize = kwargs.pop("chunksize", 1) |
| | lock_name = kwargs.pop("lock_name", "") |
| | with ensure_lock(tqdm_class, lock_name=lock_name) as lk: |
| | |
| | with PoolExecutor(max_workers=max_workers, initializer=tqdm_class.set_lock, |
| | initargs=(lk,)) as ex: |
| | return list(tqdm_class(ex.map(fn, *iterables, chunksize=chunksize), **kwargs)) |
| |
|
| |
|
| | def thread_map(fn, *iterables, **tqdm_kwargs): |
| | """ |
| | Equivalent of `list(map(fn, *iterables))` |
| | driven by `concurrent.futures.ThreadPoolExecutor`. |
| | |
| | Parameters |
| | ---------- |
| | tqdm_class : optional |
| | `tqdm` class to use for bars [default: tqdm.auto.tqdm]. |
| | max_workers : int, optional |
| | Maximum number of workers to spawn; passed to |
| | `concurrent.futures.ThreadPoolExecutor.__init__`. |
| | [default: max(32, cpu_count() + 4)]. |
| | """ |
| | from concurrent.futures import ThreadPoolExecutor |
| | return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs) |
| |
|
| |
|
| | def process_map(fn, *iterables, **tqdm_kwargs): |
| | """ |
| | Equivalent of `list(map(fn, *iterables))` |
| | driven by `concurrent.futures.ProcessPoolExecutor`. |
| | |
| | Parameters |
| | ---------- |
| | tqdm_class : optional |
| | `tqdm` class to use for bars [default: tqdm.auto.tqdm]. |
| | max_workers : int, optional |
| | Maximum number of workers to spawn; passed to |
| | `concurrent.futures.ProcessPoolExecutor.__init__`. |
| | [default: min(32, cpu_count() + 4)]. |
| | chunksize : int, optional |
| | Size of chunks sent to worker processes; passed to |
| | `concurrent.futures.ProcessPoolExecutor.map`. [default: 1]. |
| | lock_name : str, optional |
| | Member of `tqdm_class.get_lock()` to use [default: mp_lock]. |
| | """ |
| | from concurrent.futures import ProcessPoolExecutor |
| | if iterables and "chunksize" not in tqdm_kwargs: |
| | |
| | |
| | longest_iterable_len = max(map(length_hint, iterables)) |
| | if longest_iterable_len > 1000: |
| | from warnings import warn |
| | warn("Iterable length %d > 1000 but `chunksize` is not set." |
| | " This may seriously degrade multiprocess performance." |
| | " Set `chunksize=1` or more." % longest_iterable_len, |
| | TqdmWarning, stacklevel=2) |
| | if "lock_name" not in tqdm_kwargs: |
| | tqdm_kwargs = tqdm_kwargs.copy() |
| | tqdm_kwargs["lock_name"] = "mp_lock" |
| | return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs) |
| |
|