"""
Implement some common hooks.
"""
import datetime
import itertools
import logging
import math
import operator
import os
import tempfile
import time
import warnings
from collections import Counter
import torch
from fvcore.common.checkpoint import Checkpointer
from fvcore.common.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from fvcore.common.param_scheduler import ParamScheduler
from fvcore.common.timer import Timer
from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats

import detectron2.utils.comm as comm
from detectron2.evaluation.testing import flatten_results_dict
from detectron2.solver import LRMultiplier
from detectron2.solver import LRScheduler as _LRScheduler
from detectron2.utils.events import EventStorage, EventWriter
from detectron2.utils.file_io import PathManager

from .train_loop import HookBase

__all__ = [
    "CallbackHook",
    "IterationTimer",
    "PeriodicWriter",
    "PeriodicCheckpointer",
    "BestCheckpointer",
    "LRScheduler",
    "AutogradProfiler",
    "EvalHook",
    "PreciseBN",
    "TorchProfiler",
    "TorchMemoryStats",
]

class CallbackHook(HookBase):
    """
    Create a hook using callback functions provided by the user.
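
    Examples:
    ::
        # A minimal sketch; ``trainer`` is assumed to be any trainer that
        # supports ``register_hooks``, and the callbacks are arbitrary callables.
        trainer.register_hooks(
            [hooks.CallbackHook(after_step=lambda t: print("finished iter", t.iter))]
        )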
| | """ |
| |
|
| | def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None): |
| | """ |
| | Each argument is a function that takes one argument: the trainer. |
| | """ |
| | self._before_train = before_train |
| | self._before_step = before_step |
| | self._after_step = after_step |
| | self._after_train = after_train |
| |
|
| | def before_train(self): |
| | if self._before_train: |
| | self._before_train(self.trainer) |
| |
|
| | def after_train(self): |
| | if self._after_train: |
| | self._after_train(self.trainer) |
| | |
| | |
| | del self._before_train, self._after_train |
| | del self._before_step, self._after_step |
| |
|
| | def before_step(self): |
| | if self._before_step: |
| | self._before_step(self.trainer) |
| |
|
| | def after_step(self): |
| | if self._after_step: |
| | self._after_step(self.trainer) |
| |
|
| |
|
class IterationTimer(HookBase):
    """
    Track the time spent for each iteration (each run_step call in the trainer).
    Print a summary at the end of training.

    This hook uses the time between the call to its :meth:`before_step`
    and :meth:`after_step` methods.
    Under the convention that :meth:`before_step` of all hooks should only
    take a negligible amount of time, the :class:`IterationTimer` hook should be
    placed at the beginning of the list of hooks to obtain accurate timing.
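
    Examples:
    ::
        # Illustrative sketch: register this hook before all others so their
        # overhead is excluded from the timing (``other_hooks`` is a placeholder).
        trainer.register_hooks([hooks.IterationTimer(warmup_iter=3), *other_hooks])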
| | """ |
| |
|
| | def __init__(self, warmup_iter=3): |
| | """ |
| | Args: |
| | warmup_iter (int): the number of iterations at the beginning to exclude |
| | from timing. |
| | """ |
| | self._warmup_iter = warmup_iter |
| | self._step_timer = Timer() |
| | self._start_time = time.perf_counter() |
| | self._total_timer = Timer() |
| |
|
| | def before_train(self): |
| | self._start_time = time.perf_counter() |
| | self._total_timer.reset() |
| | self._total_timer.pause() |
| |
|
| | def after_train(self): |
| | logger = logging.getLogger(__name__) |
| | total_time = time.perf_counter() - self._start_time |
| | total_time_minus_hooks = self._total_timer.seconds() |
| | hook_time = total_time - total_time_minus_hooks |
| |
|
| | num_iter = self.trainer.storage.iter + 1 - self.trainer.start_iter - self._warmup_iter |
| |
|
| | if num_iter > 0 and total_time_minus_hooks > 0: |
| | |
| | |
| | logger.info( |
| | "Overall training speed: {} iterations in {} ({:.4f} s / it)".format( |
| | num_iter, |
| | str(datetime.timedelta(seconds=int(total_time_minus_hooks))), |
| | total_time_minus_hooks / num_iter, |
| | ) |
| | ) |
| |
|
| | logger.info( |
| | "Total training time: {} ({} on hooks)".format( |
| | str(datetime.timedelta(seconds=int(total_time))), |
| | str(datetime.timedelta(seconds=int(hook_time))), |
| | ) |
| | ) |
| |
|
| | def before_step(self): |
| | self._step_timer.reset() |
| | self._total_timer.resume() |
| |
|
| | def after_step(self): |
| | |
| | |
| | iter_done = self.trainer.storage.iter - self.trainer.start_iter + 1 |
| | if iter_done >= self._warmup_iter: |
| | sec = self._step_timer.seconds() |
| | self.trainer.storage.put_scalars(time=sec) |
| | else: |
| | self._start_time = time.perf_counter() |
| | self._total_timer.reset() |
| |
|
| | self._total_timer.pause() |


class PeriodicWriter(HookBase):
    """
    Periodically flush accumulated events by calling ``writer.write()``
    on each given :class:`EventWriter`.

    It is executed every ``period`` iterations and after the last iteration.
    Note that ``period`` does not affect how data is smoothed by each writer.
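
    Examples:
    ::
        # Illustrative sketch; the writer set shown here is just one choice,
        # and ``max_iter`` is a placeholder for the total iteration count.
        from detectron2.utils.events import CommonMetricPrinter, JSONWriter

        trainer.register_hooks(
            [hooks.PeriodicWriter([CommonMetricPrinter(max_iter), JSONWriter("metrics.json")])]
        )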
| | """ |
| |
|
| | def __init__(self, writers, period=20): |
| | """ |
| | Args: |
| | writers (list[EventWriter]): a list of EventWriter objects |
| | period (int): |
| | """ |
| | self._writers = writers |
| | for w in writers: |
| | assert isinstance(w, EventWriter), w |
| | self._period = period |
| |
|
| | def after_step(self): |
| | if (self.trainer.iter + 1) % self._period == 0 or ( |
| | self.trainer.iter == self.trainer.max_iter - 1 |
| | ): |
| | for writer in self._writers: |
| | writer.write() |
| |
|
| | def after_train(self): |
| | for writer in self._writers: |
| | |
| | |
| | writer.write() |
| | writer.close() |


class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
    """
    Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook.

    Note that when used as a hook,
    it is unable to save additional data other than what's defined
    by the given `checkpointer`.

    It is executed every ``period`` iterations and after the last iteration.
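
    Examples:
    ::
        # Illustrative sketch; ``checkpointer`` is assumed to be an fvcore
        # ``Checkpointer`` (e.g. detectron2's DetectionCheckpointer).
        trainer.register_hooks(
            [hooks.PeriodicCheckpointer(checkpointer, period=5000, max_to_keep=3)]
        )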
| | """ |
| |
|
| | def before_train(self): |
| | self.max_iter = self.trainer.max_iter |
| |
|
| | def after_step(self): |
| | |
| | self.step(self.trainer.iter) |


class BestCheckpointer(HookBase):
    """
    Checkpoints best weights based on the given metric.

    This hook should be used in conjunction with, and executed after, the hook
    that produces the metric, e.g. `EvalHook`.
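
    Examples:
    ::
        # Illustrative sketch; ``eval_period``, ``eval_fn`` and ``checkpointer``
        # are placeholders, and "bbox/AP50" is just an example metric key.
        trainer.register_hooks([
            hooks.EvalHook(eval_period, eval_fn),
            hooks.BestCheckpointer(eval_period, checkpointer, "bbox/AP50", mode="max"),
        ])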
| | """ |
| |
|
| | def __init__( |
| | self, |
| | eval_period: int, |
| | checkpointer: Checkpointer, |
| | val_metric: str, |
| | mode: str = "max", |
| | file_prefix: str = "model_best", |
| | ) -> None: |
| | """ |
| | Args: |
| | eval_period (int): the period `EvalHook` is set to run. |
| | checkpointer: the checkpointer object used to save checkpoints. |
| | val_metric (str): validation metric to track for best checkpoint, e.g. "bbox/AP50" |
| | mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be |
| | maximized or minimized, e.g. for "bbox/AP50" it should be "max" |
| | file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best" |
| | """ |
| | self._logger = logging.getLogger(__name__) |
| | self._period = eval_period |
| | self._val_metric = val_metric |
| | assert mode in [ |
| | "max", |
| | "min", |
| | ], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of {"max", "min"}.' |
| | if mode == "max": |
| | self._compare = operator.gt |
| | else: |
| | self._compare = operator.lt |
| | self._checkpointer = checkpointer |
| | self._file_prefix = file_prefix |
| | self.best_metric = None |
| | self.best_iter = None |
| |
|
    def _update_best(self, val, iteration):
        if math.isnan(val) or math.isinf(val):
            return False
        self.best_metric = val
        self.best_iter = iteration
        return True

    def _best_checking(self):
        metric_tuple = self.trainer.storage.latest().get(self._val_metric)
        if metric_tuple is None:
            self._logger.warning(
                f"Given val metric {self._val_metric} does not seem to be computed/stored. "
                "Will not be checkpointing based on it."
            )
            return
        else:
            latest_metric, metric_iter = metric_tuple

        if self.best_metric is None:
            if self._update_best(latest_metric, metric_iter):
                additional_state = {"iteration": metric_iter}
                self._checkpointer.save(f"{self._file_prefix}", **additional_state)
                self._logger.info(
                    f"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps"
                )
        elif self._compare(latest_metric, self.best_metric):
            additional_state = {"iteration": metric_iter}
            self._checkpointer.save(f"{self._file_prefix}", **additional_state)
            self._logger.info(
                f"Saved best model as latest eval score for {self._val_metric} is "
                f"{latest_metric:0.5f}, better than last best score "
                f"{self.best_metric:0.5f} @ iteration {self.best_iter}."
            )
            self._update_best(latest_metric, metric_iter)
        else:
            self._logger.info(
                f"Not saving as latest eval score for {self._val_metric} is {latest_metric:0.5f}, "
                f"not better than best score {self.best_metric:0.5f} @ iteration {self.best_iter}."
            )

    def after_step(self):
        # same conditions as `EvalHook`
        next_iter = self.trainer.iter + 1
        if (
            self._period > 0
            and next_iter % self._period == 0
            and next_iter != self.trainer.max_iter
        ):
            self._best_checking()

    def after_train(self):
        # same conditions as `EvalHook`
        if self.trainer.iter + 1 >= self.trainer.max_iter:
            self._best_checking()



class LRScheduler(HookBase):
    """
    A hook which executes a torch builtin LR scheduler and summarizes the LR.
    It is executed after every iteration.
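
    Examples:
    ::
        # Illustrative sketch using an fvcore ParamScheduler as a multiplier
        # over the optimizer's base LR (``optimizer`` is a placeholder):
        from fvcore.common.param_scheduler import CosineParamScheduler

        trainer.register_hooks(
            [hooks.LRScheduler(optimizer, CosineParamScheduler(1.0, 0.01))]
        )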
| | """ |
| |
|
| | def __init__(self, optimizer=None, scheduler=None): |
| | """ |
| | Args: |
| | optimizer (torch.optim.Optimizer): |
| | scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler): |
| | if a :class:`ParamScheduler` object, it defines the multiplier over the base LR |
| | in the optimizer. |
| | |
| | If any argument is not given, will try to obtain it from the trainer. |
| | """ |
| | self._optimizer = optimizer |
| | self._scheduler = scheduler |
| |
|
| | def before_train(self): |
| | self._optimizer = self._optimizer or self.trainer.optimizer |
| | if isinstance(self.scheduler, ParamScheduler): |
| | self._scheduler = LRMultiplier( |
| | self._optimizer, |
| | self.scheduler, |
| | self.trainer.max_iter, |
| | last_iter=self.trainer.iter - 1, |
| | ) |
| | self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer) |
| |
|
| | @staticmethod |
| | def get_best_param_group_id(optimizer): |
| | |
| | |
| | largest_group = max(len(g["params"]) for g in optimizer.param_groups) |
| |
|
| | if largest_group == 1: |
| | |
| | |
| | lr_count = Counter([g["lr"] for g in optimizer.param_groups]) |
| | lr = lr_count.most_common()[0][0] |
| | for i, g in enumerate(optimizer.param_groups): |
| | if g["lr"] == lr: |
| | return i |
| | else: |
| | for i, g in enumerate(optimizer.param_groups): |
| | if len(g["params"]) == largest_group: |
| | return i |
| |
|
| | def after_step(self): |
| | lr = self._optimizer.param_groups[self._best_param_group_id]["lr"] |
| | self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False) |
| | self.scheduler.step() |
| |
|
| | @property |
| | def scheduler(self): |
| | return self._scheduler or self.trainer.scheduler |
| |
|
| | def state_dict(self): |
| | if isinstance(self.scheduler, _LRScheduler): |
| | return self.scheduler.state_dict() |
| | return {} |
| |
|
| | def load_state_dict(self, state_dict): |
| | if isinstance(self.scheduler, _LRScheduler): |
| | logger = logging.getLogger(__name__) |
| | logger.info("Loading scheduler from state_dict ...") |
| | self.scheduler.load_state_dict(state_dict) |


class TorchProfiler(HookBase):
    """
    A hook which runs `torch.profiler.profile`.

    Examples:
    ::
        hooks.TorchProfiler(
            lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
        )

    The above example will run the profiler for iterations 11~19 and dump
    results to ``OUTPUT_DIR``. We did not profile the first few iterations
    because they are typically slower than the rest.
    The result files can be loaded in the ``chrome://tracing`` page of the
    Chrome browser, and the TensorBoard visualizations can be viewed using
    ``tensorboard --logdir OUTPUT_DIR/log``.
    """

    def __init__(self, enable_predicate, output_dir, *, activities=None, save_tensorboard=True):
        """
        Args:
            enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
                and returns whether to enable the profiler.
                It will be called once every step, and can be used to select which steps to profile.
            output_dir (str): the output directory to dump tracing files.
            activities (iterable): same as in `torch.profiler.profile`.
            save_tensorboard (bool): whether to save tensorboard visualizations at (output_dir)/log/
        """
        self._enable_predicate = enable_predicate
        self._activities = activities
        self._output_dir = output_dir
        self._save_tensorboard = save_tensorboard

    def before_step(self):
        if self._enable_predicate(self.trainer):
            if self._save_tensorboard:
                on_trace_ready = torch.profiler.tensorboard_trace_handler(
                    os.path.join(
                        self._output_dir,
                        "log",
                        "profiler-tensorboard-iter{}".format(self.trainer.iter),
                    ),
                    f"worker{comm.get_rank()}",
                )
            else:
                on_trace_ready = None
            self._profiler = torch.profiler.profile(
                activities=self._activities,
                on_trace_ready=on_trace_ready,
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
                with_flops=True,
            )
            self._profiler.__enter__()
        else:
            self._profiler = None

    def after_step(self):
        if self._profiler is None:
            return
        self._profiler.__exit__(None, None, None)
        if not self._save_tensorboard:
            PathManager.mkdirs(self._output_dir)
            out_file = os.path.join(
                self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter)
            )
            if "://" not in out_file:
                self._profiler.export_chrome_trace(out_file)
            else:
                # Support non-posix filesystems: export locally, then copy
                # to the destination through PathManager.
                with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d:
                    tmp_file = os.path.join(d, "tmp.json")
                    self._profiler.export_chrome_trace(tmp_file)
                    with open(tmp_file) as f:
                        content = f.read()
                    with PathManager.open(out_file, "w") as f:
                        f.write(content)



class AutogradProfiler(TorchProfiler):
    """
    A hook which runs `torch.autograd.profiler.profile`.
    Deprecated in favor of :class:`TorchProfiler`.

    Examples:
    ::
        hooks.AutogradProfiler(
            lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
        )

    The above example will run the profiler for iterations 11~19 and dump
    results to ``OUTPUT_DIR``. We did not profile the first few iterations
    because they are typically slower than the rest.
    The result files can be loaded in the ``chrome://tracing`` page of the
    Chrome browser.

    Note:
        When used together with NCCL on older version of GPUs,
        autograd profiler may cause deadlock because it unnecessarily allocates
        memory on every device it sees. The memory management calls, if
        interleaved with NCCL calls, lead to deadlock on GPUs that do not
        support ``cudaLaunchCooperativeKernelMultiDevice``.
    """

    def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
        """
        Args:
            enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
                and returns whether to enable the profiler.
                It will be called once every step, and can be used to select which steps to profile.
            output_dir (str): the output directory to dump tracing files.
            use_cuda (bool): same as in `torch.autograd.profiler.profile`.
        """
        warnings.warn("AutogradProfiler has been deprecated in favor of TorchProfiler.")
        self._enable_predicate = enable_predicate
        self._use_cuda = use_cuda
        self._output_dir = output_dir

    def before_step(self):
        if self._enable_predicate(self.trainer):
            self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
            self._profiler.__enter__()
        else:
            self._profiler = None



class EvalHook(HookBase):
    """
    Run an evaluation function periodically, and at the end of training.

    It is executed every ``eval_period`` iterations and after the last iteration.
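
    Examples:
    ::
        # Illustrative sketch; ``do_test`` is a placeholder that runs
        # inference and returns a (nested) dict of metrics.
        trainer.register_hooks(
            [hooks.EvalHook(5000, lambda: do_test(cfg, model))]
        )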
| | """ |
| |
|
| | def __init__(self, eval_period, eval_function, eval_after_train=True): |
| | """ |
| | Args: |
| | eval_period (int): the period to run `eval_function`. Set to 0 to |
| | not evaluate periodically (but still evaluate after the last iteration |
| | if `eval_after_train` is True). |
| | eval_function (callable): a function which takes no arguments, and |
| | returns a nested dict of evaluation metrics. |
| | eval_after_train (bool): whether to evaluate after the last iteration |
| | |
| | Note: |
| | This hook must be enabled in all or none workers. |
| | If you would like only certain workers to perform evaluation, |
| | give other workers a no-op function (`eval_function=lambda: None`). |
| | """ |
| | self._period = eval_period |
| | self._func = eval_function |
| | self._eval_after_train = eval_after_train |

    def _do_eval(self):
        results = self._func()

        if results:
            assert isinstance(
                results, dict
            ), "Eval function must return a dict. Got {} instead.".format(results)

            flattened_results = flatten_results_dict(results)
            for k, v in flattened_results.items():
                try:
                    v = float(v)
                except Exception as e:
                    raise ValueError(
                        "[EvalHook] eval_function should return a nested dict of float. "
                        "Got '{}: {}' instead.".format(k, v)
                    ) from e
            self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)

        # Evaluation may take different time among workers.
        # A barrier makes them start the next iteration together.
        comm.synchronize()

    def after_step(self):
        next_iter = self.trainer.iter + 1
        if self._period > 0 and next_iter % self._period == 0:
            # do the last eval in after_train
            if next_iter != self.trainer.max_iter:
                self._do_eval()

    def after_train(self):
        # This condition is to prevent the eval from running after a failed training
        if self._eval_after_train and self.trainer.iter + 1 >= self.trainer.max_iter:
            self._do_eval()
        # func is likely a closure that holds reference to the trainer.
        # To avoid circular reference, delete it here.
        del self._func



class PreciseBN(HookBase):
    """
    The standard implementation of BatchNorm uses EMA in inference, which is
    sometimes suboptimal.
    This class computes the true average of statistics rather than the moving average,
    and puts the true averages into every BN layer in the given model.

    It is executed every ``period`` iterations and after the last iteration.
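
    Examples:
    ::
        # Illustrative sketch; ``model`` and ``data_loader`` are placeholders
        # for the training model and an iterable of training batches.
        trainer.register_hooks(
            [hooks.PreciseBN(period=0, model=model, data_loader=data_loader, num_iter=200)]
        )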
| | """ |
| |
|
| | def __init__(self, period, model, data_loader, num_iter): |
| | """ |
| | Args: |
| | period (int): the period this hook is run, or 0 to not run during training. |
| | The hook will always run in the end of training. |
| | model (nn.Module): a module whose all BN layers in training mode will be |
| | updated by precise BN. |
| | Note that user is responsible for ensuring the BN layers to be |
| | updated are in training mode when this hook is triggered. |
| | data_loader (iterable): it will produce data to be run by `model(data)`. |
| | num_iter (int): number of iterations used to compute the precise |
| | statistics. |
| | """ |
| | self._logger = logging.getLogger(__name__) |
| | if len(get_bn_modules(model)) == 0: |
| | self._logger.info( |
| | "PreciseBN is disabled because model does not contain BN layers in training mode." |
| | ) |
| | self._disabled = True |
| | return |
| |
|
| | self._model = model |
| | self._data_loader = data_loader |
| | self._num_iter = num_iter |
| | self._period = period |
| | self._disabled = False |
| |
|
| | self._data_iter = None |
    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self.update_stats()

    def update_stats(self):
        """
        Update the model with precise statistics. Users can manually call this method.
        """
        if self._disabled:
            return

        if self._data_iter is None:
            self._data_iter = iter(self._data_loader)

        def data_loader():
            for num_iter in itertools.count(1):
                if num_iter % 100 == 0:
                    self._logger.info(
                        "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
                    )
                # This way we can reuse the same iterator
                yield next(self._data_iter)

        with EventStorage():  # capture events in a new storage to discard them
            self._logger.info(
                "Running precise-BN for {} iterations... ".format(self._num_iter)
                + "Note that this could produce different statistics every time."
            )
            update_bn_stats(self._model, data_loader(), self._num_iter)



class TorchMemoryStats(HookBase):
    """
    Writes PyTorch's CUDA memory statistics periodically.
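
    Examples:
    ::
        # Illustrative sketch: log memory stats every 20 iterations and
        # stop after the first 10 reports.
        trainer.register_hooks([hooks.TorchMemoryStats(period=20, max_runs=10)])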
| | """ |
| |
|
| | def __init__(self, period=20, max_runs=10): |
| | """ |
| | Args: |
| | period (int): Output stats each 'period' iterations |
| | max_runs (int): Stop the logging after 'max_runs' |
| | """ |
| |
|
| | self._logger = logging.getLogger(__name__) |
| | self._period = period |
| | self._max_runs = max_runs |
| | self._runs = 0 |

    def after_step(self):
        if self._runs >= self._max_runs:
            return

        if (self.trainer.iter + 1) % self._period == 0 or (
            self.trainer.iter == self.trainer.max_iter - 1
        ):
            if torch.cuda.is_available():
                max_reserved_mb = torch.cuda.max_memory_reserved() / 1024.0 / 1024.0
                reserved_mb = torch.cuda.memory_reserved() / 1024.0 / 1024.0
                max_allocated_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
                allocated_mb = torch.cuda.memory_allocated() / 1024.0 / 1024.0

                self._logger.info(
                    (
                        " iter: {} "
                        " max_reserved_mem: {:.0f}MB "
                        " reserved_mem: {:.0f}MB "
                        " max_allocated_mem: {:.0f}MB "
                        " allocated_mem: {:.0f}MB "
                    ).format(
                        self.trainer.iter,
                        max_reserved_mb,
                        reserved_mb,
                        max_allocated_mb,
                        allocated_mb,
                    )
                )

                self._runs += 1
                if self._runs == self._max_runs:
                    mem_summary = torch.cuda.memory_summary()
                    self._logger.info("\n" + mem_summary)

                torch.cuda.reset_peak_memory_stats()
|