| | import contextlib |
| | from typing import Any, Iterable, Iterator, Optional |
| |
|
| | try: |
| | from tqdm import tqdm |
| | except ImportError: |
| | tqdm = None |
| |
|
try:
    from ray.experimental.tqdm_ray import tqdm as ray_tqdm
except Exception:
    # ray is an optional dependency. Any ordinary failure while importing it
    # (not only ImportError) falls back to None so this module stays
    # importable; unlike a bare ``except:``, KeyboardInterrupt and SystemExit
    # are allowed to propagate.
    ray_tqdm = None
| |
|
| | |
# Module-level state shared by ``progress_bar`` (which sets and resets it) and
# ``get_new_progress_bar`` (which reads it to pick a backend).
# Name of the selected backend: "none", "tqdm" or "ray_tqdm".
_current_progress_type = "none"
# True only while a ``progress_bar`` context is open; also used to reject nesting.
_is_progress_bar_active = False
| |
|
| |
|
class DummyProgressBar:
    """A no-op progress bar that mimics the commonly used tqdm interface.

    It is handed out when progress reporting is disabled, so calling code can
    treat it exactly like a real bar: iterate over it, use it as a context
    manager, or call the usual update/description methods. Every method does
    nothing and returns ``None`` (``__enter__`` returns the bar itself).

    Args:
        iterable: Optional iterable to wrap; iterating the bar iterates it.
        **kwargs: Accepted for signature compatibility with tqdm (``total``,
            ``desc``, ...) and ignored.
    """

    def __init__(self, iterable=None, **kwargs):
        self.iterable = iterable

    def __iter__(self):
        # Like tqdm, iterating a bar created without an iterable raises TypeError.
        return iter(self.iterable)

    def __enter__(self):
        # Supports ``with get_new_progress_bar(total=n) as bar:`` usage.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def update(self, n=1):
        """Ignore a progress increment of ``n``."""
        pass

    def close(self):
        """Do nothing; there are no resources to release."""
        pass

    def set_description(self, desc=None, refresh=True):
        """Ignore the description; ``refresh`` is accepted for tqdm compatibility."""
        pass

    def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
        """Ignore postfix statistics; signature follows tqdm's ``set_postfix``."""
        pass

    def refresh(self, nolock=False, lock_args=None):
        """Do nothing; there is no display to redraw."""
        pass
| |
|
| |
|
def get_new_progress_bar(iterable: Optional[Iterable] = None, **kwargs) -> Any:
    """Create a progress bar for the currently selected backend.

    Args:
        iterable: Optional iterable for the bar to wrap.
        **kwargs: Forwarded unchanged to the backend constructor.

    Returns:
        A ``tqdm`` or ``ray_tqdm`` bar when that backend is selected inside an
        active ``progress_bar`` context, otherwise a ``DummyProgressBar``.

    Raises:
        ImportError: If the selected backend's package could not be imported.
    """
    # Outside an active ``progress_bar`` context the backend is always "none".
    backend = _current_progress_type if _is_progress_bar_active else "none"

    if backend == "tqdm":
        if tqdm is not None:
            return tqdm(iterable=iterable, **kwargs)
        raise ImportError("tqdm is required but not installed. Please install tqdm to use the tqdm progress bar.")

    if backend == "ray_tqdm":
        if ray_tqdm is not None:
            return ray_tqdm(iterable=iterable, **kwargs)
        raise ImportError("ray is required but not installed. Please install ray to use the ray_tqdm progress bar.")

    return DummyProgressBar(iterable=iterable, **kwargs)
| |
|
| |
|
@contextlib.contextmanager
def progress_bar(type: str = "none", enabled=True):
    """Select the progress bar backend used inside a ``with`` block.

    While the context is open, ``get_new_progress_bar`` returns bars of the
    selected type; on exit (normal or via an exception) the module state is
    reset so that only no-op bars are produced again.

    Per-bar options such as ``total`` or ``desc`` are NOT accepted here; pass
    them to ``get_new_progress_bar`` instead.

    Args:
        type: Backend name: "none", "tqdm" or "ray_tqdm". The name shadows
            the builtin but is kept because callers pass it by keyword.
        enabled: When false, the backend is forced to "none" (``type`` is
            still validated first).

    Raises:
        ValueError: If ``type`` is not one of the supported names.
        RuntimeError: If another ``progress_bar`` context is already active.

    Example:
        with progress_bar(type="tqdm"):
            for i in get_new_progress_bar(range(100), desc="work"):
                process(i)
    """
    global _current_progress_type, _is_progress_bar_active

    if type not in ("none", "tqdm", "ray_tqdm"):
        raise ValueError("Progress bar type must be 'none' or 'tqdm' or 'ray_tqdm'")

    # A disabled context still counts as active (nesting is rejected), but
    # every bar created inside it is a no-op.
    selected = type if enabled else "none"

    if _is_progress_bar_active:
        raise RuntimeError("Nested progress bars are not supported")

    _is_progress_bar_active = True
    _current_progress_type = selected
    try:
        yield
    finally:
        # Always restore the idle state, even if the body raised.
        _is_progress_bar_active = False
        _current_progress_type = "none"
| |
|