| |
| from multiprocessing import Pool |
| from typing import Callable, Iterable, Sized |
|
|
| from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task, |
| TaskProgressColumn, TextColumn, TimeRemainingColumn) |
| from rich.text import Text |
|
|
|
|
| class _Worker: |
| """Function wrapper for ``track_progress_rich``""" |
|
|
| def __init__(self, func) -> None: |
| self.func = func |
|
|
| def __call__(self, inputs): |
| inputs, idx = inputs |
| if not isinstance(inputs, (tuple, list)): |
| inputs = (inputs, ) |
|
|
| return self.func(*inputs), idx |
|
|
|
|
class _SkipFirstTimeRemainingColumn(TimeRemainingColumn):
    """Remaining-time column that shows a placeholder at the start.

    As long as the number of completed steps of a task has not exceeded
    ``skip_times``, a fixed ``-:--:--`` text is rendered instead of the
    estimate produced by the parent column.

    Args:
        skip_times (int): The number of times to skip. Defaults to 0.
    """

    def __init__(self, *args, skip_times=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.skip_times = skip_times

    def render(self, task: Task) -> Text:
        """Show time remaining, or a placeholder while still skipping."""
        past_skip_window = task.completed > self.skip_times
        if past_skip_window:
            return super().render(task)
        return Text('-:--:--', style='progress.remaining')
|
|
|
|
| def _tasks_with_index(tasks): |
| """Add index to tasks.""" |
| for idx, task in enumerate(tasks): |
| yield task, idx |
|
|
|
|
def track_progress_rich(func: Callable,
                        tasks: Iterable = tuple(),
                        task_num: int = None,
                        nproc: int = 1,
                        chunksize: int = 1,
                        description: str = 'Processing',
                        color: str = 'blue') -> list:
    """Track the progress of parallel task execution with a progress bar.

    With ``nproc == 1`` the tasks are executed one after another in the
    current process. Otherwise a :class:`multiprocessing.Pool` is created
    and the tasks are dispatched with :func:`Pool.imap_unordered`; the
    results are put back into task order before being returned.

    Args:
        func (callable): The function to be applied to each task.
        tasks (Iterable or Sized): A tuple of tasks. There are several cases
            for different format tasks:

            - When ``func`` accepts no arguments: tasks should be an empty
              tuple, and ``task_num`` must be specified.
            - When ``func`` accepts only one argument: tasks should be a
              tuple containing the argument.
            - When ``func`` accepts multiple arguments: tasks should be a
              tuple, with each element representing a set of arguments.
              If an element is a ``dict``, it is meant to be passed as a
              set of keyword arguments (the unpacking itself is performed
              by ``_Worker``).

            Defaults to an empty tuple.
        task_num (int, optional): If ``tasks`` is an iterator which does not
            have length, the number of tasks can be provided by
            ``task_num``; it is forwarded as the total of the progress bar.
            Defaults to None.
        nproc (int): Process (worker) number, if ``nproc`` is 1,
            use single process. Defaults to 1.
        chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
            Only used when ``nproc > 1``. Defaults to 1.
        description (str): The description of progress bar.
            Defaults to "Processing".
        color (str): The color of progress bar. Defaults to "blue".

    Examples:
        >>> import time

        >>> def func(x):
        ...    time.sleep(1)
        ...    return x**2
        >>> track_progress_rich(func, range(10), nproc=2)

    Returns:
        list: The task results, in the same order as ``tasks``.

    Raises:
        TypeError: If ``func`` is not callable or ``tasks`` is not iterable.
        ValueError: If ``tasks`` is empty and ``task_num`` is not given, if
            ``task_num`` disagrees with ``len(tasks)``, or if ``nproc`` is
            not positive.
    """
    if not callable(func):
        raise TypeError('func must be a callable object')
    if not isinstance(tasks, Iterable):
        raise TypeError(
            f'tasks must be an iterable object, but got {type(tasks)}')
    if isinstance(tasks, Sized):
        if len(tasks) == 0:
            if task_num is None:
                raise ValueError('If tasks is an empty iterable, '
                                 'task_num must be set')
            # ``func`` takes no arguments: build ``task_num`` empty
            # argument sets so that it is simply called that many times.
            tasks = tuple(tuple() for _ in range(task_num))
        else:
            if task_num is not None and task_num != len(tasks):
                raise ValueError('task_num does not match the length of tasks')
            task_num = len(tasks)

    if nproc <= 0:
        raise ValueError('nproc must be a positive number')

    # With several workers the first ``nproc * chunksize`` completions are
    # excluded from the remaining-time display.
    skip_times = nproc * chunksize if nproc > 1 else 0
    prog_bar = Progress(
        TextColumn('{task.description}'),
        # Storing ``color`` as a task field alone has no visible effect
        # because no column reads it; apply it to the bar itself.
        BarColumn(complete_style=color),
        _SkipFirstTimeRemainingColumn(skip_times=skip_times),
        MofNCompleteColumn(),
        TaskProgressColumn(show_speed=True),
    )

    worker = _Worker(func)
    # ``color`` is still passed as a task field, as before.
    task_id = prog_bar.add_task(
        total=task_num, color=color, description=description)
    indexed_tasks = _tasks_with_index(tasks)

    with prog_bar:
        if nproc == 1:
            # Sequential execution already yields results in task order.
            results = []
            for task in indexed_tasks:
                results.append(worker(task)[0])
                prog_bar.update(task_id, advance=1, refresh=True)
        else:
            with Pool(nproc) as pool:
                finished = []
                gen = pool.imap_unordered(worker, indexed_tasks, chunksize)
                try:
                    for result, idx in gen:
                        finished.append((result, idx))
                        prog_bar.update(task_id, advance=1, refresh=True)
                except Exception:
                    # Stop the display before the pool is torn down, then
                    # re-raise with the original traceback untouched.
                    prog_bar.stop()
                    raise
            # Results arrive in completion order; the index returned by the
            # worker is used to restore the order of ``tasks``.
            results = [None] * len(finished)
            for result, idx in finished:
                results[idx] = result
    return results
|
|