| import os |
| import traceback |
| from functools import partial |
| from tqdm import tqdm |
|
|
|
|
def chunked_worker(worker_id, args_queue=None, results_queue=None, init_ctx_func=None):
    """Worker loop: pull jobs from ``args_queue`` and report results on ``results_queue``.

    Each item read from ``args_queue`` is either the string ``'<KILL>'``, which makes
    the worker return, or a ``(job_idx, map_func, arg)`` triple. ``arg`` is passed to
    ``map_func`` as keyword arguments when it is a dict, as positional arguments when
    it is a list/tuple, and as the single argument otherwise.

    :param worker_id: index of this worker; forwarded to ``init_ctx_func``.
    :param args_queue: queue to read jobs (and the ``'<KILL>'`` marker) from.
    :param results_queue: queue that receives one ``(job_idx, result)`` pair per job.
        If ``map_func`` raises an ``Exception``, its traceback is printed and
        ``(job_idx, None)`` is reported instead.
    :param init_ctx_func: optional factory called once as ``init_ctx_func(worker_id)``;
        when it returns a non-None value, that value is passed to every job as the
        ``ctx`` keyword argument.
    """
    ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None
    while True:
        args = args_queue.get()
        if args == '<KILL>':
            return
        job_idx, map_func, arg = args
        try:
            map_func_ = partial(map_func, ctx=ctx) if ctx is not None else map_func
            if isinstance(arg, dict):
                res = map_func_(**arg)
            elif isinstance(arg, (list, tuple)):
                res = map_func_(*arg)
            else:
                res = map_func_(arg)
            results_queue.put((job_idx, res))
        except Exception:
            # Report the failure as a None result so the consumer, which waits for
            # exactly one result per job, does not block on this job. Only ordinary
            # exceptions are caught: KeyboardInterrupt/SystemExit now stop the worker
            # instead of being swallowed (a bare ``except:`` would keep looping).
            traceback.print_exc()
            results_queue.put((job_idx, None))
|
|
|
|
class MultiprocessManager:
    """Fan jobs out to a pool of ``chunked_worker`` processes (or threads).

    Jobs are ``(job_idx, func, args)`` triples put on ``args_queue``. When that queue
    is bounded and full, jobs wait in ``jobs_pending`` until ``get_results`` finds
    room for them. Workers report ``(job_idx, result)`` pairs on ``results_queue``.
    """

    def __init__(self, num_workers=None, init_ctx_func=None, multithread=False, queue_max=-1):
        """Create the queues and start ``num_workers`` workers immediately.

        :param num_workers: number of workers; defaults to ``$N_PROC`` or, failing
            that, the CPU count (1 if the CPU count cannot be determined).
        :param init_ctx_func: optional per-worker context factory forwarded to
            ``chunked_worker``.
        :param multithread: if True, use ``multiprocessing.dummy`` (threads) instead
            of real processes.
        :param queue_max: ``maxsize`` of the job queue; ``<= 0`` leaves it effectively
            unbounded.
        """
        from collections import deque
        if multithread:
            from multiprocessing.dummy import Queue, Process
        else:
            from multiprocessing import Queue, Process
        if num_workers is None:
            # os.cpu_count() may return None, which int() would reject.
            num_workers = int(os.getenv('N_PROC', os.cpu_count() or 1))
        self.num_workers = num_workers
        self.results_queue = Queue(maxsize=-1)
        # deque so get_results can take jobs off the front in O(1).
        self.jobs_pending = deque()
        self.args_queue = Queue(maxsize=queue_max)
        self.workers = []
        self.total_jobs = 0
        self.multithread = multithread
        for i in range(num_workers):
            p = Process(target=chunked_worker,
                        args=(i, self.args_queue, self.results_queue, init_ctx_func))
            # Daemonize processes AND threads: if get_results is abandoned before the
            # '<KILL>' markers are sent, the workers stay blocked on args_queue.get(),
            # and non-daemon threads would prevent interpreter exit. (Set as an
            # attribute because multiprocessing.dummy's Process has no daemon kwarg.)
            p.daemon = True
            self.workers.append(p)
            p.start()

    def add_job(self, func, args):
        """Submit ``func`` applied to ``args`` as job number ``self.total_jobs``.

        The job goes straight onto the job queue if it has room, otherwise it is
        parked in ``jobs_pending`` for ``get_results`` to enqueue later.
        """
        if not self.args_queue.full():
            self.args_queue.put((self.total_jobs, func, args))
        else:
            self.jobs_pending.append((self.total_jobs, func, args))
        self.total_jobs += 1

    def get_results(self):
        """Yield ``(job_idx, result)`` pairs as workers report them, then stop the workers.

        Before each blocking read of ``results_queue``, as many pending jobs as fit
        are moved onto the job queue. Once every submitted job has reported, one
        ``'<KILL>'`` marker per worker is queued and all workers are joined.
        """
        self.n_finished = 0
        while self.n_finished < self.total_jobs:
            while self.jobs_pending and not self.args_queue.full():
                self.args_queue.put(self.jobs_pending.popleft())
            job_id, res = self.results_queue.get()
            yield job_id, res
            self.n_finished += 1
        for _ in range(self.num_workers):
            self.args_queue.put("<KILL>")
        for w in self.workers:
            w.join()

    def close(self):
        """Terminate worker processes; a no-op in multithread mode (threads cannot be terminated)."""
        if not self.multithread:
            for w in self.workers:
                w.terminate()

    def __len__(self):
        """Return the number of jobs submitted so far."""
        return self.total_jobs
|
|
|
|
def multiprocess_run_tqdm(map_func, args, num_workers=None, ordered=True, init_ctx_func=None,
                          multithread=False, queue_max=-1, desc=None):
    """Same as ``multiprocess_run`` but wrapped in a ``tqdm`` progress bar.

    The bar's total is ``len(args)``, so ``args`` must be sized here.

    :param desc: label shown on the progress bar.
    :return: generator of the ``(job_idx, result)`` pairs produced by ``multiprocess_run``.
    """
    job_stream = multiprocess_run(map_func, args, num_workers, ordered, init_ctx_func, multithread,
                                  queue_max=queue_max)
    progress = tqdm(job_stream, total=len(args), desc=desc)
    for job_idx, result in progress:
        yield job_idx, result
|
|
|
|
def multiprocess_run(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, multithread=False,
                     queue_max=-1):
    """
    Run ``map_func`` over every element of ``args`` on a pool of workers.

    Examples:
        >>> for idx, res in tqdm(multiprocess_run(job_func, args)):
        >>>     print(idx, res)

    :param map_func: function applied to each job. A dict element is passed as
        keyword arguments, a list/tuple element as positional arguments, anything
        else as the single argument.
    :param args: iterable of per-job arguments (one job per element).
    :param num_workers: number of workers; defaults to ``$N_PROC`` or the CPU count
        (1 if the CPU count cannot be determined).
    :param ordered: if True, yield results in the order of ``args``; otherwise in the
        order workers report them.
    :param init_ctx_func: optional per-worker context factory; its non-None return
        value is passed to every job as ``ctx``.
    :param multithread: if True, use threads instead of processes.
    :param queue_max: maximum size of the job queue; ``<= 0`` leaves it effectively unbounded.
    :return: generator of ``(job_idx, result)`` pairs; ``result`` is None for jobs
        whose function raised.
    """
    if num_workers is None:
        # os.cpu_count() may return None, which int() would reject.
        num_workers = int(os.getenv('N_PROC', os.cpu_count() or 1))

    manager = MultiprocessManager(num_workers, init_ctx_func, multithread, queue_max=queue_max)
    try:
        for arg in args:
            manager.add_job(map_func, arg)
        if ordered:
            # Count submitted jobs instead of len(args) so any iterable works.
            n_jobs = len(manager)
            # Private sentinel compared by identity: it can never collide with a real
            # result (a '<WAIT>' string could), and `is` avoids elementwise `!=` on
            # array-like results.
            waiting = object()
            results = [waiting] * n_jobs
            i_now = 0
            for job_i, res in manager.get_results():
                results[job_i] = res
                while i_now < n_jobs and results[i_now] is not waiting:
                    yield i_now, results[i_now]
                    results[i_now] = None  # drop the reference once yielded
                    i_now += 1
        else:
            for job_i, res in manager.get_results():
                yield job_i, res
    finally:
        # Also runs when the consumer abandons the generator early, so worker
        # processes are not left running.
        manager.close()
|
|