DeepSolanaCoder
/
DeepSeek-Coder-main
/finetune
/venv
/lib
/python3.12
/site-packages
/datasets
/parallel
/parallel.py
| import contextlib | |
| from multiprocessing import Pool, RLock | |
| from tqdm.auto import tqdm | |
| from ..utils import experimental, logging | |
| logger = logging.get_logger(__name__) | |
class ParallelBackendConfig:
    """Module-wide holder for the name of the parallel backend currently in use."""

    # `None` makes `parallel_map` use `multiprocessing.Pool`; any other value is
    # handed to joblib as the backend name.
    backend_name = None
def parallel_map(function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func):
    """
    **Experimental.** Apply `function` to the elements of `iterable` in parallel, using either
    `multiprocessing.Pool` (when no backend is configured) or joblib (when one is).

    Args:
        function (`Callable[[Any], Any]`): Function to be applied to `iterable`.
        iterable (`list`, `tuple` or `np.ndarray`): Iterable elements to apply function to.
        num_proc (`int`): Number of processes (if no backend specified) or jobs (using joblib).
        batched: Forwarded unchanged to `single_map_nested_func`.
        batch_size: Forwarded unchanged to `single_map_nested_func`.
        types (`tuple`): Additional types (besides `dict` values) to apply `function` recursively to their elements.
        disable_tqdm (`bool`): Whether to disable the tqdm progressbar.
        desc (`str`): Prefix for the tqdm progressbar.
        single_map_nested_func (`Callable`): Map function that applies `function` to an element from `iterable`.
            It is called with a single tuple
            `(function, data_struct, batched, batch_size, types, rank, disable_tqdm, desc)`, where `data_struct`
            is taken from `iterable` and `rank` is used for the progress bar.

    Returns:
        Whatever the selected implementation returns for the mapped elements.
    """
    # Choose the implementation once, then make a single call with the shared argument list.
    if ParallelBackendConfig.backend_name is None:
        run_map = _map_with_multiprocessing_pool
    else:
        run_map = _map_with_joblib
    return run_map(
        function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
    )
| def _map_with_multiprocessing_pool( | |
| function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func | |
| ): | |
| num_proc = num_proc if num_proc <= len(iterable) else len(iterable) | |
| split_kwds = [] # We organize the splits ourselve (contiguous splits) | |
| for index in range(num_proc): | |
| div = len(iterable) // num_proc | |
| mod = len(iterable) % num_proc | |
| start = div * index + min(index, mod) | |
| end = start + div + (1 if index < mod else 0) | |
| split_kwds.append((function, iterable[start:end], batched, batch_size, types, index, disable_tqdm, desc)) | |
| if len(iterable) != sum(len(i[1]) for i in split_kwds): | |
| raise ValueError( | |
| f"Error dividing inputs iterable among processes. " | |
| f"Total number of objects {len(iterable)}, " | |
| f"length: {sum(len(i[1]) for i in split_kwds)}" | |
| ) | |
| logger.info( | |
| f"Spawning {num_proc} processes for {len(iterable)} objects in slices of {[len(i[1]) for i in split_kwds]}" | |
| ) | |
| initargs, initializer = None, None | |
| if not disable_tqdm: | |
| initargs, initializer = (RLock(),), tqdm.set_lock | |
| with Pool(num_proc, initargs=initargs, initializer=initializer) as pool: | |
| mapped = pool.map(single_map_nested_func, split_kwds) | |
| logger.info(f"Finished {num_proc} processes") | |
| mapped = [obj for proc_res in mapped for obj in proc_res] | |
| logger.info(f"Unpacked {len(mapped)} objects") | |
| return mapped | |
def _map_with_joblib(
    function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
):
    """Map over `iterable` with joblib, using the backend named in `ParallelBackendConfig`.

    Each element of `iterable` becomes one joblib task that calls `single_map_nested_func` with the
    tuple `(function, element, batched, batch_size, types, None, True, None)`. The `disable_tqdm` and
    `desc` arguments are accepted but not used: the rank is `None`, tqdm is always disabled and no
    description is passed.

    Returns:
        The list produced by `joblib.Parallel`, one entry per element of `iterable`.
    """
    # progress bar is not yet supported for _map_with_joblib, because tqdm couldn't accurately be applied to joblib,
    # and it requires monkey-patching joblib internal classes which is subject to change
    import joblib

    active_backend = ParallelBackendConfig.backend_name
    with joblib.parallel_backend(active_backend, n_jobs=num_proc):
        executor = joblib.Parallel()
        tasks = (
            joblib.delayed(single_map_nested_func)((function, element, batched, batch_size, types, None, True, None))
            for element in iterable
        )
        return executor(tasks)
@contextlib.contextmanager
def parallel_backend(backend_name: str):
    """
    **Experimental.** Configures the parallel backend for parallelized dataset loading, which uses the parallelization
    implemented by joblib.

    While the context is active, `ParallelBackendConfig.backend_name` is set to `backend_name`; on exit
    (normal or exceptional) the value that was in place before entering is restored, so contexts can be
    nested.

    Args:
        backend_name (str): Name of backend for parallelization implementation, has to be supported by joblib.

    Raises:
        ImportError: If `backend_name` is `"spark"` and `joblibspark` cannot be imported. The configuration is
            left unchanged in that case.

    Example usage:
    ```py
    with parallel_backend('spark'):
        dataset = load_dataset(..., num_proc=2)
    ```
    """
    if backend_name == "spark":
        # Register before touching the global config, so a failed import cannot leave
        # `backend_name` stuck on "spark".
        from joblibspark import register_spark

        register_spark()

    # TODO: call create_cache_and_write_probe if "download" in steps
    # TODO: raise NotImplementedError when Dataset.map etc is called

    previous_backend = ParallelBackendConfig.backend_name
    ParallelBackendConfig.backend_name = backend_name
    try:
        yield
    finally:
        # Restore the previous value rather than forcing `None`, so an enclosing context keeps its backend.
        ParallelBackendConfig.backend_name = previous_backend