"""Contains the logger class built on the `rich` module.

The logger renders log messages and progress bars with `rich`. Please refer
to https://github.com/Textualize/rich for more details.
"""
|
|
| import sys |
| import logging |
| from copy import deepcopy |
| from rich.console import Console |
| from rich.logging import RichHandler |
| from rich.progress import Progress |
| from rich.progress import ProgressColumn |
| from rich.progress import TextColumn |
| from rich.progress import BarColumn |
| from rich.text import Text |
|
|
| from .base_logger import BaseLogger |
|
|
# Public API of this module: only the logger class is exported.
__all__ = [
    'RichLogger',
]
|
|
|
|
| def _format_time(seconds): |
| """Formats seconds to readable time string. |
| |
| This function is used to display time in progress bar. |
| """ |
| if not seconds: |
| return '--:--' |
|
|
| seconds = int(seconds) |
| hours, seconds = divmod(seconds, 3600) |
| minutes, seconds = divmod(seconds, 60) |
| if hours: |
| return f'{hours}:{minutes:02d}:{seconds:02d}' |
| return f'{minutes:02d}:{seconds:02d}' |
|
|
|
|
class TimeColumn(ProgressColumn):
    """Progress-bar column showing elapsed time, ETA and processing speed.

    The rendered text has the form `[elapsed<eta, speed]`, for example
    `[00:12<00:30, 4.20/s]`. Unknown times are shown as `--:--` and an
    unknown (or zero) speed as `?/s`.
    """

    # Cache the rendered text for up to 0.5 seconds between re-renders
    # (rich's `ProgressColumn.max_refresh`).
    max_refresh = 0.5

    def render(self, task):
        """Returns the timing text for `task`, styled as remaining time."""
        elapsed_text = _format_time(task.elapsed)
        remaining_text = _format_time(task.time_remaining)
        rate = task.speed
        if rate:
            rate_text = f'{rate:.2f}/s'
        else:
            rate_text = '?/s'
        label = '[{}<{}, {}]'.format(elapsed_text, remaining_text, rate_text)
        return Text(label, style='progress.remaining')
|
|
|
|
class RichLogger(BaseLogger):
    """Implements the logger based on `rich` module.

    Messages are written to stdout through a `RichHandler`. If `logfile` is
    given, they are also appended to that file through a second
    `RichHandler`. The logger can additionally drive a `rich` progress bar
    rendered on the terminal console.
    """

    def __init__(self,
                 logger_name='logger',
                 logfile=None,
                 screen_level=logging.INFO,
                 file_level=logging.DEBUG,
                 indent_space=4,
                 verbose_log=False):
        """Initializes the logger and attaches its output handlers.

        Args:
            logger_name: Name passed to `logging.getLogger()`. The named
                logger must not already have handlers.
            logfile: Path of the file to append log messages to. No file
                output is set up if it is falsy.
            screen_level: Level of the terminal handler.
            file_level: Level of the file handler.
            indent_space: Forwarded to `BaseLogger`.
            verbose_log: Forwarded to `BaseLogger`.

        Raises:
            SystemExit: If the logger named `logger_name` already has
                handlers attached.
        """
        super().__init__(logger_name=logger_name,
                         logfile=logfile,
                         screen_level=screen_level,
                         file_level=file_level,
                         indent_space=indent_space,
                         verbose_log=verbose_log)

        self.logger = logging.getLogger(self.logger_name)
        # Disable propagation first. `hasHandlers()` then only inspects this
        # logger (not its ancestors), and ancestor handlers do not duplicate
        # the messages.
        self.logger.propagate = False
        if self.logger.hasHandlers():
            raise SystemExit(f'Logger `{self.logger_name}` has already '
                             f'existed!\n'
                             f'Please use another name, or otherwise the '
                             f'messages may be mixed up.')

        # Accept records down to DEBUG at the logger level. Each handler then
        # filters with its own level (`screen_level` / `file_level`).
        self.logger.setLevel(logging.DEBUG)

        # Terminal output. This handler is added first, and `init_pbar()`
        # relies on that.
        terminal_console = Console(
            file=sys.stdout, log_time=False, log_path=False)
        self.logger.addHandler(
            self._build_handler(terminal_console, self.screen_level))

        # File output, opened in append mode so existing logs are kept.
        if self.logfile:
            # NOTE(review): this stream is never closed inside this class.
            self.file_stream = open(self.logfile, 'a', encoding='utf-8')
            file_console = Console(
                file=self.file_stream, log_time=False, log_path=False)
            self.logger.addHandler(
                self._build_handler(file_console, self.file_level))

        self.pbar = None
        self.pbar_kwargs = {}

    @staticmethod
    def _build_handler(console, level):
        """Creates a `RichHandler` at `level` that writes to `console`."""
        handler = RichHandler(
            level=level,
            console=console,
            show_time=True,
            show_level=True,
            show_path=False,
            log_time_format='[%Y-%m-%d %H:%M:%S] ')
        # The handler adds time and level itself; only the text is formatted.
        handler.setFormatter(logging.Formatter('%(message)s'))
        return handler

    def _log(self, message, level=logging.INFO, **kwargs):
        """Logs `message` at `level` (INFO unless specified).

        `logging.Logger.log()` takes the level as its first positional
        argument, so it must be passed explicitly before the message.
        """
        self.logger.log(level, message, **kwargs)

    def _debug(self, message, **kwargs):
        """Logs `message` at DEBUG level."""
        self.logger.debug(message, **kwargs)

    def _info(self, message, **kwargs):
        """Logs `message` at INFO level."""
        self.logger.info(message, **kwargs)

    def _warning(self, message, **kwargs):
        """Logs `message` at WARNING level."""
        self.logger.warning(message, **kwargs)

    def _error(self, message, **kwargs):
        """Logs `message` at ERROR level."""
        self.logger.error(message, **kwargs)

    def _exception(self, message, **kwargs):
        """Logs `message` at ERROR level with the current exception info."""
        self.logger.exception(message, **kwargs)

    def _critical(self, message, **kwargs):
        """Logs `message` at CRITICAL level."""
        self.logger.critical(message, **kwargs)

    def _print(self, *messages, **kwargs):
        """Prints `messages` directly on the console of every handler.

        This bypasses the logging machinery, so no timestamp or level is
        added.
        """
        for handler in self.logger.handlers:
            handler.console.print(*messages, **kwargs)

    def init_pbar(self, leave=False):
        """Creates and starts a progress bar on the terminal console.

        Args:
            leave: Whether to keep the progress bar displayed after it is
                stopped. If `False`, the bar is transient.
        """
        assert self.pbar is None

        # Columns: description, bar, percentage, and elapsed/ETA/speed.
        columns = (
            TextColumn('[progress.description]{task.description}'),
            BarColumn(bar_width=None),
            TextColumn('[progress.percentage]{task.percentage:>5.1f}%'),
            TimeColumn(),
        )

        # The first handler is the terminal one (added first in `__init__`).
        self.pbar = Progress(*columns,
                             console=self.logger.handlers[0].console,
                             transient=not leave,
                             auto_refresh=True,
                             refresh_per_second=10)
        self.pbar.start()

    def add_pbar_task(self, name, total, **kwargs):
        """Adds a task to the progress bar and returns its task ID.

        Keyword arguments override a copy of `self.pbar_kwargs`. The merged
        arguments are forwarded to `Progress.add_task()`.
        """
        assert isinstance(self.pbar, Progress)
        assert isinstance(self.pbar_kwargs, dict)
        pbar_kwargs = deepcopy(self.pbar_kwargs)
        pbar_kwargs.update(kwargs)
        return self.pbar.add_task(name, total=total, **pbar_kwargs)

    def update_pbar(self, task_id, advance=1):
        """Advances task `task_id` by `advance` and stops it once finished.

        Updates on an already-finished task are ignored.
        """
        assert isinstance(self.pbar, Progress)
        # NOTE(review): indexing `tasks` by ID assumes IDs match positions,
        # which holds as long as no task is removed from the progress bar.
        task = self.pbar.tasks[task_id]
        if not task.finished:
            self.pbar.update(task_id, advance=advance)
        # Stop the task as soon as it completes. Otherwise its elapsed time
        # keeps growing until yet another update call arrives.
        if task.finished and task.stop_time is None:
            self.pbar.stop_task(task_id)

    def close_pbar(self):
        """Stops the progress bar and resets the progress-bar state."""
        assert isinstance(self.pbar, Progress)
        self.pbar.stop()
        self.pbar = None
        self.pbar_kwargs = {}
|
|