| | |
| | from ...dist_utils import master_only |
| | from ..hook import HOOKS |
| | from .base import LoggerHook |
| |
|
| |
|
@HOOKS.register_module()
class WandbLoggerHook(LoggerHook):
    """Logger hook that sends the runner's loggable tags to Weights & Biases.

    Args:
        init_kwargs (dict, optional): Keyword arguments forwarded to
            ``wandb.init()`` in :meth:`before_run`. When falsy (``None`` or
            an empty dict), ``wandb.init()`` is called with no arguments.
            Default: None.
        interval (int): Forwarded to :class:`LoggerHook`. Default: 10.
        ignore_last (bool): Forwarded to :class:`LoggerHook`. Default: True.
        reset_flag (bool): Forwarded to :class:`LoggerHook`. Default: False.
        commit (bool): Passed as ``commit=`` to every ``wandb.log()`` call.
            Default: True.
        by_epoch (bool): Forwarded to :class:`LoggerHook`. Default: True.
        with_step (bool): If True, the value of ``self.get_iter(runner)`` is
            passed as wandb's ``step=`` argument. If False, it is instead
            stored in the logged dict under the ``'global_step'`` key.
            Default: True.

    Raises:
        ImportError: At construction time, if the ``wandb`` package cannot
            be imported.
    """

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 commit=True,
                 by_epoch=True,
                 with_step=True):
        super(WandbLoggerHook, self).__init__(interval, ignore_last,
                                              reset_flag, by_epoch)
        # Import eagerly so a missing dependency fails at construction time
        # rather than at the start of training.
        self.import_wandb()
        self.init_kwargs = init_kwargs
        self.commit = commit
        self.with_step = with_step

    def import_wandb(self):
        """Import the ``wandb`` module and store it on ``self.wandb``.

        Raises:
            ImportError: If ``wandb`` is not installed.
        """
        try:
            import wandb
        except ImportError:
            raise ImportError(
                'Please run "pip install wandb" to install wandb')
        self.wandb = wandb

    @master_only
    def before_run(self, runner):
        """Start a wandb run before training begins.

        Uses ``self.init_kwargs`` as keyword arguments to ``wandb.init()``
        when it is truthy.
        """
        super(WandbLoggerHook, self).before_run(runner)
        # ``__init__`` already imports wandb (or raises); this re-import only
        # matters if ``self.wandb`` was reset to None after construction.
        if self.wandb is None:
            self.import_wandb()
        if self.init_kwargs:
            self.wandb.init(**self.init_kwargs)
        else:
            self.wandb.init()

    @master_only
    def log(self, runner):
        """Send the current loggable tags to wandb.

        Nothing is sent when ``self.get_loggable_tags(runner)`` returns an
        empty/falsy value.
        """
        tags = self.get_loggable_tags(runner)
        if not tags:
            return
        if self.with_step:
            self.wandb.log(
                tags, step=self.get_iter(runner), commit=self.commit)
        else:
            # Record the iteration as an ordinary metric instead of wandb's
            # ``step`` argument. Note: this mutates the ``tags`` dict that
            # ``get_loggable_tags`` returned.
            tags['global_step'] = self.get_iter(runner)
            self.wandb.log(tags, commit=self.commit)

    @master_only
    def after_run(self, runner):
        """Finish the wandb run once training is over."""
        # ``wandb.join`` is a deprecated alias of ``wandb.finish``. Prefer the
        # current name and fall back to ``join`` only if ``finish`` is absent.
        finish = getattr(self.wandb, 'finish', None)
        if finish is None:
            finish = self.wandb.join
        finish()
| |
|