| |
| |
| |
| |
|
|
| """ |
| Wrapper around various loggers and progress bars (e.g., tqdm). |
| """ |
|
|
| import atexit |
| import json |
| import logging |
| import os |
| import sys |
| from collections import OrderedDict |
| from contextlib import contextmanager |
| from numbers import Number |
| from typing import Optional |
|
|
| import torch |
|
|
| from .meters import AverageMeter, StopwatchMeter, TimeMeter |
|
|
|
|
# Module-level logger used by every progress bar in this file. The bars
# temporarily rename it (via ``rename_logger``) when a ``tag`` is supplied.
logger = logging.getLogger(__name__)
|
|
|
|
def progress_bar(
    iterator,
    log_format: Optional[str] = None,
    log_interval: int = 100,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    tensorboard_logdir: Optional[str] = None,
    default_log_format: str = "tqdm",
):
    """Wrap *iterator* in the progress bar matching *log_format*.

    Args:
        iterator: the iterable to wrap.
        log_format: one of ``"json"``, ``"none"``, ``"simple"`` or ``"tqdm"``;
            ``None`` selects *default_log_format*.
        log_interval: forwarded to the JSON and simple bars (and to the
            optional ``FbTbmfWrapper``).
        epoch: forwarded to the selected bar.
        prefix: forwarded to the selected bar.
        tensorboard_logdir: when truthy, the bar is additionally wrapped so
            that stats are sent to tensorboard.
        default_log_format: format used when *log_format* is ``None``.

    Returns:
        The constructed (and possibly wrapped) progress bar.

    Raises:
        ValueError: if the resolved format is not one of the known names.
    """
    chosen = default_log_format if log_format is None else log_format
    # tqdm is only used when stderr is a terminal; otherwise use the simple bar.
    if chosen == "tqdm" and not sys.stderr.isatty():
        chosen = "simple"

    # Factories are lazy so that only the selected class is ever touched.
    factories = {
        "json": lambda: JsonProgressBar(iterator, epoch, prefix, log_interval),
        "none": lambda: NoopProgressBar(iterator, epoch, prefix),
        "simple": lambda: SimpleProgressBar(iterator, epoch, prefix, log_interval),
        "tqdm": lambda: TqdmProgressBar(iterator, epoch, prefix),
    }
    make_bar = factories.get(chosen)
    if make_bar is None:
        raise ValueError("Unknown log format: {}".format(chosen))
    bar = make_bar()

    if not tensorboard_logdir:
        return bar

    try:
        # NOTE(review): ``palaas`` is not used directly; importing it looks
        # like a probe for an environment where FbTbmfWrapper is available.
        import palaas  # noqa: F401
        from .fb_tbmf_wrapper import FbTbmfWrapper

        return FbTbmfWrapper(bar, log_interval)
    except ImportError:
        return TensorboardProgressBarWrapper(bar, tensorboard_logdir)
|
|
|
|
def build_progress_bar(
    args,
    iterator,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    default: str = "tqdm",
    no_progress_bar: str = "none",
):
    """Legacy wrapper that takes an argparse.Namespace.

    Reads ``log_format`` and ``log_interval`` from *args* (both required) and
    the optional ``no_progress_bar``, ``distributed_rank`` and
    ``tensorboard_logdir`` attributes, then delegates to ``progress_bar``.

    Args:
        args: namespace-like object holding the logging options.
        iterator: the iterable to wrap.
        epoch: forwarded to ``progress_bar``.
        prefix: forwarded to ``progress_bar``.
        default: default log format when ``args.no_progress_bar`` is falsy.
        no_progress_bar: default log format when ``args.no_progress_bar`` is
            truthy.

    Returns:
        Whatever ``progress_bar`` returns.
    """
    fallback_format = (
        no_progress_bar if getattr(args, "no_progress_bar", False) else default
    )
    # Only rank 0 (or a namespace without ``distributed_rank``) gets a
    # tensorboard directory; every other rank logs without tensorboard.
    on_rank_zero = getattr(args, "distributed_rank", 0) == 0
    tb_dir = getattr(args, "tensorboard_logdir", None) if on_rank_zero else None
    return progress_bar(
        iterator,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch,
        prefix=prefix,
        tensorboard_logdir=tb_dir,
        default_log_format=fallback_format,
    )
|
|
|
|
def format_stat(stat):
    """Convert a single stat into a loggable value.

    Numbers are rendered with ``{:g}``; an ``AverageMeter`` as its ``avg`` with
    three decimals; a ``TimeMeter`` as its rounded ``avg``; a
    ``StopwatchMeter`` as its rounded ``sum``; a tensor is converted with
    ``tolist()``. Anything else is returned unchanged.
    """
    # The checks run in this fixed order; the first match wins.
    if isinstance(stat, Number):
        return "{:g}".format(stat)
    if isinstance(stat, AverageMeter):
        return "{:.3f}".format(stat.avg)
    if isinstance(stat, TimeMeter):
        return "{:g}".format(round(stat.avg))
    if isinstance(stat, StopwatchMeter):
        return "{:g}".format(round(stat.sum))
    if torch.is_tensor(stat):
        return stat.tolist()
    return stat
|
|
|
|
class BaseProgressBar(object):
    """Abstract class for progress bars.

    Subclasses implement ``__iter__``, ``log`` and ``print``. The base class
    stores the wrapped iterable, builds the textual prefix used in log lines
    and offers helpers for formatting stats.

    Attributes set by ``__init__``:
        iterable: the wrapped iterable.
        n: the iterable's ``n`` attribute if it has one, else 0 (subclasses in
            this file use it as the start index when enumerating).
        epoch: the epoch passed in, or ``None``.
        prefix: ``"epoch NNN"`` and/or the caller's prefix, joined by ``" | "``.
    """

    def __init__(self, iterable, epoch=None, prefix=None):
        self.iterable = iterable
        self.n = getattr(iterable, "n", 0)
        self.epoch = epoch
        # Join only the parts that are present, so a prefix given without an
        # epoch does not start with a dangling " | " separator.
        parts = []
        if epoch is not None:
            parts.append("epoch {:03d}".format(epoch))
        if prefix is not None:
            parts.append("{}".format(prefix))
        self.prefix = " | ".join(parts)

    def __len__(self):
        return len(self.iterable)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def __iter__(self):
        raise NotImplementedError

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        raise NotImplementedError

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        raise NotImplementedError

    def _str_commas(self, stats):
        """Render already-formatted (string) stats as ``key=value, ...``."""
        return ", ".join(
            key + "=" + value.strip() for key, value in stats.items()
        )

    def _str_pipes(self, stats):
        """Render already-formatted (string) stats as ``key value | ...``."""
        return " | ".join(
            key + " " + value.strip() for key, value in stats.items()
        )

    def _format_stats(self, stats):
        """Return an ordered copy of *stats* with every value as a string."""
        postfix = OrderedDict(stats)
        # Values are replaced in place; keys (and their order) are untouched.
        for key in postfix.keys():
            postfix[key] = str(format_stat(postfix[key]))
        return postfix
|
|
|
|
@contextmanager
def rename_logger(logger, new_name):
    """Temporarily set ``logger.name`` to *new_name* inside a ``with`` block.

    Args:
        logger: the logger whose ``name`` attribute is swapped.
        new_name: the temporary name; ``None`` leaves the name unchanged.

    Yields:
        The same *logger* object.

    The original name is restored when the block exits, including when the
    body raises.
    """
    old_name = logger.name
    if new_name is not None:
        logger.name = new_name
    try:
        yield logger
    finally:
        # Restore even on error so a failing log call cannot leave the shared
        # logger permanently renamed.
        logger.name = old_name
|
|
|
|
class JsonProgressBar(BaseProgressBar):
    """Log output in JSON format."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        # Index of the item most recently yielded; None until iteration starts.
        self.i = None
        # Length of the iterable; filled in when iteration starts.
        self.size = None

    def __iter__(self):
        self.size = len(self.iterable)
        for i, obj in enumerate(self.iterable, start=self.n):
            self.i = i
            yield obj

    def _epoch_progress(self):
        """Return fractional epoch progress, or None when it is unknown.

        Progress is unknown when no epoch was given or when iteration has not
        started yet (``self.i`` / ``self.size`` are still unset).
        """
        if self.epoch is None or self.i is None or not self.size:
            return None
        return self.epoch - 1 + (self.i + 1) / float(self.size)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval.

        A record is emitted only when the effective step (*step*, else the
        current item index) is positive and a multiple of ``log_interval``.
        """
        step = step or self.i or 0
        if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
            # An explicit ``step`` may arrive before iteration has begun; in
            # that case the "update" field is omitted instead of raising.
            update = self._epoch_progress()
            stats = self._format_stats(stats, epoch=self.epoch, update=update)
            with rename_logger(logger, tag):
                logger.info(json.dumps(stats))

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats.

        When *tag* is given, every key is prefixed with ``"<tag>_"``. The
        (possibly re-keyed) stats are also kept on ``self.stats``.
        """
        self.stats = stats
        if tag is not None:
            self.stats = OrderedDict(
                [(tag + "_" + k, v) for k, v in self.stats.items()]
            )
        stats = self._format_stats(self.stats, epoch=self.epoch)
        with rename_logger(logger, tag):
            logger.info(json.dumps(stats))

    def _format_stats(self, stats, epoch=None, update=None):
        """Build the ordered record that is serialized to JSON.

        ``epoch`` and ``update`` (rounded to 3 decimals) come first when they
        are not None, followed by each stat passed through ``format_stat``.
        """
        postfix = OrderedDict()
        if epoch is not None:
            postfix["epoch"] = epoch
        if update is not None:
            postfix["update"] = round(update, 3)
        for key in stats.keys():
            postfix[key] = format_stat(stats[key])
        return postfix
|
|
|
|
class NoopProgressBar(BaseProgressBar):
    """No logging.

    Iterating yields the wrapped iterable's items unchanged; ``log`` and
    ``print`` accept the usual arguments and do nothing.
    """

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)

    def __iter__(self):
        yield from self.iterable

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
|
|
|
|
class SimpleProgressBar(BaseProgressBar):
    """A minimal logger for non-TTY environments."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        # Both are filled in by ``__iter__``.
        self.i = None
        self.size = None

    def __iter__(self):
        self.size = len(self.iterable)
        for index, item in enumerate(self.iterable, start=self.n):
            self.i = index
            yield item

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval.

        Emits ``"<prefix>: <position> / <size> key=value, ..."`` when the
        effective step (*step*, else the current index) is positive and a
        multiple of ``log_interval``.
        """
        step = step or self.i or 0
        due = (
            step > 0
            and self.log_interval is not None
            and step % self.log_interval == 0
        )
        if not due:
            return
        # NOTE(review): this path reads self.i and self.size, which are None
        # until iteration has started; callers are assumed to log only while
        # iterating.
        summary = self._str_commas(self._format_stats(stats))
        with rename_logger(logger, tag):
            message = "{}: {:5d} / {:d} {}".format(
                self.prefix, self.i + 1, self.size, summary
            )
            logger.info(message)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        formatted = self._format_stats(stats)
        summary = self._str_pipes(formatted)
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, summary))
|
|
|
|
class TqdmProgressBar(BaseProgressBar):
    """Log to tqdm."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)
        # Imported here so tqdm is only required when this bar is selected.
        from tqdm import tqdm

        # The bar is disabled when the module logger's effective level is
        # stricter than INFO.
        silenced = logger.getEffectiveLevel() > logging.INFO
        self.tqdm = tqdm(iterable, self.prefix, leave=False, disable=silenced)

    def __iter__(self):
        return iter(self.tqdm)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        # Stats are attached as the bar's postfix without forcing a redraw.
        formatted = self._format_stats(stats)
        self.tqdm.set_postfix(formatted, refresh=False)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        formatted = self._format_stats(stats)
        summary = self._str_pipes(formatted)
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, summary))
|
|
|
|
# Cache of tensorboard writers, filled lazily by
# TensorboardProgressBarWrapper._writer and keyed by the tag it is given.
_tensorboard_writers = {}

try:
    from tensorboardX import SummaryWriter
except ImportError:
    # tensorboardX is optional; None signals "no tensorboard logging".
    SummaryWriter = None
|
|
|
|
# ``atexit.register`` returns the function it is given, so the decorator form
# both registers the hook and keeps ``_close_writers`` bound to the function.
@atexit.register
def _close_writers():
    """Close every cached tensorboard writer (runs at interpreter exit)."""
    for writer in _tensorboard_writers.values():
        writer.close()
|
|
|
|
class TensorboardProgressBarWrapper(BaseProgressBar):
    """Log to tensorboard.

    Wraps another progress bar: every ``log``/``print`` call first writes the
    numeric stats to tensorboard (when tensorboardX is importable) and is then
    forwarded to the wrapped bar.
    """

    def __init__(self, wrapped_bar, tensorboard_logdir):
        # NOTE: BaseProgressBar.__init__ is deliberately not called here; the
        # wrapper holds no iterable of its own and delegates to wrapped_bar.
        self.wrapped_bar = wrapped_bar
        self.tensorboard_logdir = tensorboard_logdir

        if SummaryWriter is None:
            logger.warning(
                "tensorboard not found, please install with: pip install tensorboardX"
            )

    def _writer(self, key):
        """Return the cached writer for *key*, creating it on first use.

        Returns None when tensorboardX is not available.
        """
        if SummaryWriter is None:
            return None
        _writers = _tensorboard_writers
        # NOTE(review): the cache is keyed by ``key`` alone, so a writer made
        # for one logdir is reused by wrappers created with another logdir.
        if key not in _writers:
            _writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key))
            _writers[key].add_text("sys.argv", " ".join(sys.argv))
        return _writers[key]

    def __len__(self):
        # The inherited __len__ reads self.iterable, which this wrapper never
        # sets; delegate to the wrapped bar instead of raising AttributeError.
        return len(self.wrapped_bar)

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to tensorboard."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def _log_to_tensorboard(self, stats, tag=None, step=None):
        """Write the scalar entries of *stats* to the writer for *tag*.

        Args:
            stats: mapping of stat name to value.
            tag: selects the writer; ``None`` maps to the ``""`` key.
            step: global step; when None, ``stats["num_updates"]`` is used.

        Raises:
            KeyError: if *step* is None and *stats* has no ``"num_updates"``.
        """
        writer = self._writer(tag or "")
        if writer is None:
            return
        if step is None:
            step = stats["num_updates"]
        # "num_updates" is the step axis, so it is not logged as a scalar.
        for key in stats.keys() - {"num_updates"}:
            if isinstance(stats[key], AverageMeter):
                # NOTE(review): logs the meter's ``val`` (not ``avg``) —
                # presumably the most recent value; confirm against meters.
                writer.add_scalar(key, stats[key].val, step)
            elif isinstance(stats[key], Number):
                writer.add_scalar(key, stats[key], step)
        writer.flush()
|
|