extensions
/
microsoftexcel-controlnet
/annotator
/mmpkg
/mmcv
/runner
/hooks
/logger
/tensorboard.py
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os.path as osp | |
| from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version | |
| from ...dist_utils import master_only | |
| from ..hook import HOOKS | |
| from .base import LoggerHook | |
class TensorboardLoggerHook(LoggerHook):
    """Logger hook that writes the runner's loggable tags to TensorBoard.

    String-valued tags are written with ``add_text``; every other tag is
    written with ``add_scalar``. The ``SummaryWriter`` implementation is
    chosen in :meth:`before_run` depending on ``TORCH_VERSION``.

    NOTE(review): this module imports ``HOOKS`` and ``master_only`` but
    neither is applied in the code visible here. Confirm whether the class
    is meant to be registered via ``HOOKS`` and whether ``before_run`` /
    ``log`` / ``after_run`` are meant to be wrapped with ``master_only``.

    Args:
        log_dir (str, optional): Directory handed to ``SummaryWriter``.
            When ``None``, ``<runner.work_dir>/tf_logs`` is used (resolved
            in :meth:`before_run`). Default: None.
        interval: Forwarded unchanged to ``LoggerHook.__init__``.
            Default: 10.
        ignore_last: Forwarded unchanged to ``LoggerHook.__init__``.
            Default: True.
        reset_flag: Forwarded unchanged to ``LoggerHook.__init__``.
            Default: False.
        by_epoch: Forwarded unchanged to ``LoggerHook.__init__``.
            Default: True.
    """

    def __init__(self,
                 log_dir=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True):
        super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
                                                    reset_flag, by_epoch)
        self.log_dir = log_dir
        # The writer is only created in ``before_run``. Keeping an explicit
        # ``None`` until then lets ``after_run`` tell whether there is
        # anything to close instead of failing with ``AttributeError``.
        self.writer = None

    def before_run(self, runner):
        """Create the ``SummaryWriter`` used by :meth:`log`.

        Args:
            runner: The runner this hook is attached to; its ``work_dir``
                is read when no ``log_dir`` was given.

        Raises:
            ImportError: If the required TensorBoard backend
                (``tensorboardX`` or ``torch.utils.tensorboard``) cannot be
                imported.
        """
        super(TensorboardLoggerHook, self).before_run(runner)

        # 'parrots' builds and PyTorch < 1.1 use tensorboardX; anything
        # newer uses the writer bundled with PyTorch.
        use_tensorboardx = (
            TORCH_VERSION == 'parrots'
            or digit_version(TORCH_VERSION) < digit_version('1.1'))

        if use_tensorboardx:
            try:
                from tensorboardX import SummaryWriter
            except ImportError as err:
                # Chain explicitly so the underlying import failure stays
                # attached as the cause of the friendlier message.
                raise ImportError('Please install tensorboardX to use '
                                  'TensorboardLoggerHook.') from err
        else:
            try:
                from torch.utils.tensorboard import SummaryWriter
            except ImportError as err:
                raise ImportError(
                    'Please run "pip install future tensorboard" to install '
                    'the dependencies to use torch.utils.tensorboard '
                    '(applicable to PyTorch 1.1 or higher)') from err

        if self.log_dir is None:
            self.log_dir = osp.join(runner.work_dir, 'tf_logs')
        self.writer = SummaryWriter(self.log_dir)

    def log(self, runner):
        """Write every loggable tag of ``runner`` to the summary writer.

        Args:
            runner: The runner whose tags are collected through
                ``get_loggable_tags(runner, allow_text=True)``.
        """
        tags = self.get_loggable_tags(runner, allow_text=True)
        for tag, val in tags.items():
            # Strings go to the text panel; everything else is a scalar.
            if isinstance(val, str):
                self.writer.add_text(tag, val, self.get_iter(runner))
            else:
                self.writer.add_scalar(tag, val, self.get_iter(runner))

    def after_run(self, runner):
        """Close the summary writer if one was created.

        Args:
            runner: The runner this hook is attached to (unused here).
        """
        # ``getattr`` also covers instances whose ``__init__`` was bypassed.
        writer = getattr(self, 'writer', None)
        if writer is not None:
            writer.close()