|
|
from rich import print |
|
|
from dataclasses import dataclass |
|
|
from pytorch_lightning.utilities import rank_zero_only |
|
|
from typing import Union |
|
|
from pytorch_lightning.callbacks.progress.rich_progress import * |
|
|
from rich.console import Console, RenderableType |
|
|
from rich.progress_bar import ProgressBar |
|
|
from rich.style import Style |
|
|
from rich.text import Text |
|
|
from rich.progress import ( |
|
|
BarColumn, |
|
|
DownloadColumn, |
|
|
Progress, |
|
|
TaskID, |
|
|
TextColumn, |
|
|
TimeRemainingColumn, |
|
|
TransferSpeedColumn, |
|
|
ProgressColumn |
|
|
) |
|
|
from rich import print, reconfigure |
|
|
|
|
|
@rank_zero_only
def print_only(message: str) -> None:
    """Print ``message`` through rich's ``print``, on the rank-zero process only.

    ``rank_zero_only`` guards the call so that, in a multi-process run, only
    global rank 0 actually emits the message.

    Args:
        message: Text (or rich markup) to print.
    """
    print(message)
|
|
|
|
|
@dataclass
class RichProgressBarTheme:
    """Style settings for each visual component of the rich progress bar.

    Every field accepts either a style string (such as a hex colour like
    ``"#FF4500"``) or a ``rich.style.Style`` instance. For the accepted style
    syntax, see https://rich.readthedocs.io/en/stable/style.html

    Attributes:
        description: Style of the bar's description text, e.g. "Epoch x" or "Testing".
        progress_bar: Style of a bar that is still in progress.
        progress_bar_finished: Style of a bar that has finished.
        progress_bar_pulse: Style of the bar while an ``IterableDataset`` is being processed.
        batch_progress: Style of the batch counter (i.e. "10/50" batches completed).
        time: Style of the elapsed time and the estimated time remaining.
        processing_speed: Style of the batch processing speed.
        metrics: Style of the metrics text.
    """

    # Description text shown in front of the bar.
    description: Union[str, Style] = "#FF4500"

    # The bar itself, in its in-progress / finished / pulsing states.
    progress_bar: Union[str, Style] = "#f92672"
    progress_bar_finished: Union[str, Style] = "#b7cc8a"
    progress_bar_pulse: Union[str, Style] = "#f92672"

    # Auxiliary columns rendered next to the bar.
    batch_progress: Union[str, Style] = "#fc608a"
    time: Union[str, Style] = "#45ada2"
    processing_speed: Union[str, Style] = "#DC143C"
    metrics: Union[str, Style] = "#228B22"
|
|
|
|
|
class BatchesProcessedColumn(ProgressColumn):
    """Progress column that renders the batch counter as ``completed/total``.

    Args:
        style: Rich style (a style string or ``rich.style.Style``) applied to the text.
    """

    def __init__(self, style: Union[str, Style]):
        self.style = style
        super().__init__()

    def render(self, task) -> RenderableType:
        """Render the task's progress, e.g. ``"10/50"``.

        If the task's total is infinite or unknown (``None``), the total is
        rendered as ``"--"``, e.g. ``"10/--"``.

        Args:
            task: The rich progress task being rendered; its ``completed``
                and ``total`` values are read.

        Returns:
            A ``Text`` renderable styled with ``self.style``.
        """
        total = task.total
        # Only a finite, known total can be converted to int; converting the
        # "--" placeholder (or None) with int() would raise.
        if total is None or total == float("inf"):
            total_text = "--"
        else:
            total_text = str(int(total))
        return Text(f"{int(task.completed)}/{total_text}", style=self.style)
|
|
|
|
|
class MyMetricsTextColumn(ProgressColumn):
    """Progress column that renders the most recently supplied metrics as text.

    Args:
        style: Rich style applied to the rendered metrics text.
    """

    def __init__(self, style):
        # NOTE(review): _tasks and _current_task_id are never read inside this
        # class; they are kept in case other code relies on them.
        self._tasks = {}
        self._current_task_id = 0
        self._metrics = {}
        self._style = style
        super().__init__()

    def update(self, metrics):
        """Replace the stored metrics with ``metrics`` (a mapping of name -> value)."""
        self._metrics = metrics

    def render(self, task) -> Text:
        """Render each stored metric as ``"name: value "``, in mapping order.

        Float values are rounded to 3 decimal places; every other value is
        formatted as-is. Each entry, including the last one, is followed by a
        single space.
        """
        pieces = []
        for name, value in self._metrics.items():
            shown = round(value, 3) if isinstance(value, float) else value
            pieces.append(f"{name}: {shown} ")
        return Text("".join(pieces), justify="left", style=self._style)
|
|
|
|
|
class MyRichProgressBar(RichProgressBar):
    """``RichProgressBar`` variant that renders into its own force-terminal console.

    Only ``_init_progress`` is overridden; all other behavior is inherited.
    """

    def _init_progress(self, trainer):
        """Create and start a fresh live progress display, if one is needed.

        Does nothing unless the bar is enabled AND either no progress object
        exists yet or the previous one has been stopped.

        Args:
            trainer: The Lightning trainer, forwarded to ``configure_columns``
                and to the metrics column.
        """
        # Guard clauses evaluated in the same order as the original combined
        # condition: is_enabled first, then the progress/stopped check.
        if not self.is_enabled:
            return
        if self.progress is not None and not self._progress_stopped:
            return

        self._reset_progress_bar_ids()
        # reconfigure() updates rich's global console.
        # NOTE(review): the dedicated Console below is created without
        # self._console_kwargs, so those settings do not reach it — confirm intended.
        reconfigure(**self._console_kwargs)

        self._console: Console = Console(force_terminal=True)
        self._console.clear_live()

        # NOTE(review): this uses MetricsTextColumn from the rich_progress
        # star import, not the MyMetricsTextColumn defined in this file — confirm intended.
        self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
        base_columns = self.configure_columns(trainer)
        self.progress = CustomProgress(
            *base_columns,
            self._metric_component,
            auto_refresh=False,
            disable=self.is_disabled,
            console=self._console,
        )
        self.progress.start()

        # Mark the freshly started display as active.
        self._progress_stopped = False