| | |
| | |
| | import concurrent.futures |
| | import logging |
| | import numpy as np |
| | import time |
| | import weakref |
| | from typing import List, Mapping, Optional |
| | import torch |
| | from torch.nn.parallel import DataParallel, DistributedDataParallel |
| |
|
| | import detectron2.utils.comm as comm |
| | from detectron2.utils.events import EventStorage, get_event_storage |
| | from detectron2.utils.logger import _log_api_usage |
| |
|
| | __all__ = ["HookBase", "TrainerBase", "SimpleTrainer", "AMPTrainer"] |
| |
|
| |
|
class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.

    A hook may override any of the callbacks defined below; every default
    implementation does nothing. The following snippet shows when the
    loop-level callbacks are invoked:
    ::
        hook.before_train()
        for iter in range(start_iter, max_iter):
            hook.before_step()
            trainer.run_step()
            hook.after_step()
        iter += 1
        hook.after_train()

    :meth:`after_backward` is dispatched through the trainer's own
    ``after_backward()``; :class:`SimpleTrainer` calls it from inside
    ``run_step`` right after the backward pass.

    Notes:
        1. Inside a hook method, ``self.trainer`` gives access to more
           properties about the context (e.g., model, current iteration, or config
           if using :class:`DefaultTrainer`).

        2. Work done in :meth:`before_step` can often be done equivalently in
           :meth:`after_step`.
           If the hook takes non-trivial time, it is strongly recommended to
           put the work in :meth:`after_step` instead of :meth:`before_step`.
           The convention is that :meth:`before_step` should only take negligible time.

           Following this convention allows hooks that do care about the difference
           between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
           function properly.

    """

    trainer: "TrainerBase" = None
    """
    A weak reference to the trainer object. Assigned by the trainer when the hook is registered.
    """

    def before_train(self):
        """
        Callback run once, before the first iteration.
        """

    def after_train(self):
        """
        Callback run once, after the last iteration.
        """

    def before_step(self):
        """
        Callback run at the start of every iteration.
        """

    def after_backward(self):
        """
        Callback run in every iteration once the backward pass is done.
        """

    def after_step(self):
        """
        Callback run at the end of every iteration.
        """

    def state_dict(self):
        """
        Return the checkpointable state of the hook.

        The default is an empty dict, i.e. hooks are stateless. A hook becomes
        checkpointable by implementing both `state_dict` and `load_state_dict`.
        """
        return dict()
| |
|
| |
|
class TrainerBase:
    """
    Base class for iterative trainer with hooks.

    The single assumption made here is that training runs in a loop;
    what one iteration does is up to the subclass (see :meth:`run_step`).
    Nothing is assumed about the existence of a dataloader, optimizer, model, etc.

    Attributes:
        iter(int): the current iteration.

        start_iter(int): The iteration to start with.
            By convention the minimum possible value is 0.

        max_iter(int): The iteration to end training.

        storage(EventStorage): An EventStorage that's opened during the course of training.
    """

    def __init__(self) -> None:
        self._hooks: List[HookBase] = []
        self.iter: int = 0
        self.start_iter: int = 0
        # Declared here for type checkers only; both are assigned in `train()`.
        self.max_iter: int
        self.storage: EventStorage
        _log_api_usage(f"trainer.{self.__class__.__name__}")

    def register_hooks(self, hooks: List[Optional[HookBase]]) -> None:
        """
        Attach hooks to this trainer. Hooks run in registration order.

        Args:
            hooks (list[Optional[HookBase]]): list of hooks; ``None`` entries
                are dropped.
        """
        usable = [hook for hook in hooks if hook is not None]
        for hook in usable:
            assert isinstance(hook, HookBase)
            # The hook only gets a weak proxy of the trainer, so the hook does
            # not hold a strong reference back to the object that owns it.
            hook.trainer = weakref.proxy(self)
        self._hooks += usable

    def train(self, start_iter: int, max_iter: int):
        """
        Run the training loop over ``range(start_iter, max_iter)``.

        Args:
            start_iter, max_iter (int): See docs above
        """
        logger = logging.getLogger(__name__)
        logger.info(f"Starting training from iteration {start_iter}")

        self.iter = start_iter
        self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for step in range(start_iter, max_iter):
                    self.iter = step
                    self.before_step()
                    self.run_step()
                    self.after_step()
                # Bumped once more after the loop: a run that executed at least
                # one iteration without raising ends with self.iter == max_iter,
                # whereas an interrupted run keeps the failing iteration.
                self.iter += 1
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                # Runs on success and on failure alike.
                self.after_train()

    def before_train(self):
        """Call ``before_train`` on every registered hook, in order."""
        for hook in self._hooks:
            hook.before_train()

    def after_train(self):
        """Sync the storage iteration, then call ``after_train`` on every hook."""
        self.storage.iter = self.iter
        for hook in self._hooks:
            hook.after_train()

    def before_step(self):
        """Sync the storage iteration, then call ``before_step`` on every hook."""
        # Keep the storage's iteration counter equal to the trainer's before
        # any hook of this iteration runs.
        self.storage.iter = self.iter

        for hook in self._hooks:
            hook.before_step()

    def after_backward(self):
        """Call ``after_backward`` on every registered hook, in order."""
        for hook in self._hooks:
            hook.after_backward()

    def after_step(self):
        """Call ``after_step`` on every registered hook, in order."""
        for hook in self._hooks:
            hook.after_step()

    def run_step(self):
        """Perform one iteration. Must be provided by subclasses."""
        raise NotImplementedError

    def state_dict(self):
        """
        Collect the checkpointable state of the trainer.

        Returns:
            dict: ``{"iteration": ...}`` plus, when at least one hook reports a
            non-empty state, a ``"hooks"`` dict keyed by the hook class's
            ``__qualname__``.
        """
        state = {"iteration": self.iter}
        per_hook = {}
        for hook in self._hooks:
            hook_state = hook.state_dict()
            if not hook_state:
                continue
            # When several stateful hooks share a class, only the first one
            # registered is stored; later ones are skipped.
            per_hook.setdefault(type(hook).__qualname__, hook_state)
        if per_hook:
            state["hooks"] = per_hook
        return state

    def load_state_dict(self, state_dict):
        """
        Restore the iteration counter and any saved hook states.

        Each saved hook state is handed to the first registered hook whose
        class ``__qualname__`` equals the saved key; a warning is logged for
        keys with no matching hook.
        """
        logger = logging.getLogger(__name__)
        self.iter = state_dict["iteration"]
        saved_hooks = state_dict.get("hooks", {})
        for key, value in saved_hooks.items():
            restored = False
            for hook in self._hooks:
                try:
                    hook_name = type(hook).__qualname__
                except AttributeError:
                    continue
                if hook_name != key:
                    continue
                hook.load_state_dict(value)
                restored = True
                break
            if not restored:
                logger.warning(f"Cannot find the hook '{key}', its state_dict is ignored.")
| |
|
| |
|
class SimpleTrainer(TrainerBase):
    """
    A simple trainer for the most common type of task:
    single-cost single-optimizer single-data-source iterative optimization,
    optionally using data-parallelism.
    Every step is assumed to:

    1. Compute the loss with a data from the data_loader.
    2. Compute the gradients with the above loss.
    3. Update the model with the optimizer.

    Everything else that happens during training (checkpointing, logging,
    evaluation, LR schedule) is left to hooks, which can be registered by
    :meth:`TrainerBase.register_hooks`.

    For anything fancier than this,
    either subclass TrainerBase and implement your own `run_step`,
    or write your own training loop.
    """

    def __init__(
        self,
        model,
        data_loader,
        optimizer,
        gather_metric_period=1,
        zero_grad_before_forward=False,
        async_write_metrics=False,
    ):
        """
        Args:
            model: a torch Module. Takes a data from data_loader and returns a
                dict of losses.
            data_loader: an iterable. Contains data to be used to call model.
            optimizer: a torch optimizer.
            gather_metric_period: an int. Every gather_metric_period iterations
                the metrics are gathered from all the ranks to rank 0 and logged.
            zero_grad_before_forward: whether to zero the gradients before the forward.
            async_write_metrics: bool. If True, then write metrics asynchronously to improve
                training speed
        """
        super().__init__()

        # The trainer puts the model into training mode.
        # Training a model that's in eval mode is valid, though: to make a
        # model (or a submodule of it) behave like evaluation during training,
        # overwrite its train() method.
        model.train()

        self.model = model
        self.optimizer = optimizer
        self.data_loader = data_loader
        # Created lazily; read it through the `_data_loader_iter` property.
        self._data_loader_iter_obj = None
        self.gather_metric_period = gather_metric_period
        self.zero_grad_before_forward = zero_grad_before_forward
        self.async_write_metrics = async_write_metrics
        # Executor used by run_step for the asynchronous metric write.
        # It has a single worker, so submitted tasks are processed one at a time.
        self.concurrent_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def run_step(self):
        """
        Implement the standard training logic described above.
        """
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"

        # To do something with the data, wrap the dataloader.
        fetch_begin = time.perf_counter()
        batch = next(self._data_loader_iter)
        data_time = time.perf_counter() - fetch_begin

        if self.zero_grad_before_forward:
            # To accumulate gradients or do something similar, wrap the
            # optimizer with a custom `zero_grad()` method.
            self.optimizer.zero_grad()

        # To do something with the losses, wrap the model.
        loss_dict = self.model(batch)
        if not isinstance(loss_dict, torch.Tensor):
            total_loss = sum(loss_dict.values())
        else:
            # A bare tensor is treated as the total loss and wrapped so that
            # the metric writer always receives a mapping.
            total_loss = loss_dict
            loss_dict = {"total_loss": total_loss}

        if not self.zero_grad_before_forward:
            # To accumulate gradients or do something similar, wrap the
            # optimizer with a custom `zero_grad()` method.
            self.optimizer.zero_grad()
        total_loss.backward()

        self.after_backward()

        if not self.async_write_metrics:
            self._write_metrics(loss_dict, data_time)
        else:
            # Hand the write to the executor; the current iteration is passed
            # explicitly instead of being read from self.iter by the worker.
            self.concurrent_executor.submit(
                self._write_metrics, loss_dict, data_time, iter=self.iter
            )

        # Gradient clipping/scaling or other processing can be done by wrapping
        # the optimizer with a custom `step()` method. But it is
        # suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4
        self.optimizer.step()

    @property
    def _data_loader_iter(self):
        """Iterator over ``self.data_loader``, created on first access and then reused."""
        loader_iter = self._data_loader_iter_obj
        if loader_iter is None:
            loader_iter = iter(self.data_loader)
            self._data_loader_iter_obj = loader_iter
        return loader_iter

    def reset_data_loader(self, data_loader_builder):
        """
        Drop the current data loader and install a new one obtained by calling
        `data_loader_builder` (without argument). The cached iterator is
        cleared so that the next access iterates the new loader.
        """
        del self.data_loader
        self.data_loader = data_loader_builder()
        self._data_loader_iter_obj = None

    def _write_metrics(
        self,
        loss_dict: Mapping[str, torch.Tensor],
        data_time: float,
        prefix: str = "",
        iter: Optional[int] = None,
    ) -> None:
        """
        Forward to :meth:`write_metrics` on iterations where
        ``(iter + 1)`` is a multiple of ``gather_metric_period``.

        Args:
            loss_dict (dict): dict of scalar losses
            data_time (float): time taken by the dataloader iteration
            prefix (str): prefix for logging keys
            iter (int or None): iteration to log under; ``self.iter`` is used
                when it is None.
        """
        logger = logging.getLogger(__name__)

        if iter is None:
            cur_iter = self.iter
        else:
            cur_iter = iter

        if (cur_iter + 1) % self.gather_metric_period != 0:
            return

        try:
            SimpleTrainer.write_metrics(loss_dict, data_time, cur_iter, prefix)
        except Exception:
            logger.exception("Exception in writing metrics: ")
            raise

    @staticmethod
    def write_metrics(
        loss_dict: Mapping[str, torch.Tensor],
        data_time: float,
        cur_iter: int,
        prefix: str = "",
    ) -> None:
        """
        Gather the metrics of all workers and record them in the event storage
        on the main process.

        Args:
            loss_dict (dict): dict of scalar losses
            data_time (float): time taken by the dataloader iteration
            cur_iter (int): iteration the scalars are recorded under
            prefix (str): prefix for logging keys

        Raises:
            FloatingPointError: on the main process, if the sum of the averaged
                losses is not finite.
        """
        local_metrics = {}
        for name, value in loss_dict.items():
            local_metrics[name] = value.detach().cpu().item()
        local_metrics["data_time"] = data_time

        # Every process takes part in the gather; only the main process goes on
        # to use the result.
        gathered = comm.gather(local_metrics)
        if not comm.is_main_process():
            return

        storage = get_event_storage()

        # data_time is reported as the maximum over workers and is removed from
        # the per-worker dicts so it is not averaged with the losses below.
        slowest_data_time = np.max([worker.pop("data_time") for worker in gathered])
        storage.put_scalar("data_time", slowest_data_time, cur_iter=cur_iter)

        # All remaining metrics are averaged over workers.
        averaged = {}
        for name in gathered[0]:
            averaged[name] = np.mean([worker[name] for worker in gathered])

        total_losses_reduced = sum(averaged.values())
        if not np.isfinite(total_losses_reduced):
            raise FloatingPointError(
                f"Loss became infinite or NaN at iteration={cur_iter}!\n"
                f"loss_dict = {averaged}"
            )

        storage.put_scalar(f"{prefix}total_loss", total_losses_reduced, cur_iter=cur_iter)
        if len(averaged) > 1:
            storage.put_scalars(cur_iter=cur_iter, **averaged)

    def state_dict(self):
        """Return the base trainer state extended with the optimizer state."""
        state = super().state_dict()
        state["optimizer"] = self.optimizer.state_dict()
        return state

    def load_state_dict(self, state_dict):
        """Restore the base trainer state, then the optimizer state."""
        super().load_state_dict(state_dict)
        optimizer_state = state_dict["optimizer"]
        self.optimizer.load_state_dict(optimizer_state)

    def after_train(self):
        """Run the hooks' ``after_train``, then shut the executor down, waiting for pending tasks."""
        super().after_train()
        self.concurrent_executor.shutdown(wait=True)
| |
|
| |
|
class AMPTrainer(SimpleTrainer):
    """
    Like :class:`SimpleTrainer`, but uses PyTorch's native automatic mixed precision
    in the training loop.
    """

    def __init__(
        self,
        model,
        data_loader,
        optimizer,
        gather_metric_period=1,
        zero_grad_before_forward=False,
        grad_scaler=None,
        precision: torch.dtype = torch.float16,
        log_grad_scaler: bool = False,
        async_write_metrics=False,
    ):
        """
        Args:
            model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward,
                async_write_metrics: same as in :class:`SimpleTrainer`.
            grad_scaler: torch GradScaler to automatically scale gradients.
                A default ``torch.cuda.amp.GradScaler()`` is created when None.
            precision: torch.dtype as the target precision to cast to in computations
            log_grad_scaler: if True, every step records the current scale of
                ``grad_scaler`` in the event storage under ``"[metric]grad_scaler"``.
        """
        unsupported = "AMPTrainer does not support single-process multi-device training!"
        if isinstance(model, DistributedDataParallel):
            assert not (model.device_ids and len(model.device_ids) > 1), unsupported
        assert not isinstance(model, DataParallel), unsupported

        # `async_write_metrics` must be forwarded: `run_step` below branches on
        # `self.async_write_metrics`, which is set by SimpleTrainer.__init__.
        # Without forwarding it the attribute is always False and the
        # asynchronous path can never be taken.
        super().__init__(
            model,
            data_loader,
            optimizer,
            gather_metric_period,
            zero_grad_before_forward,
            async_write_metrics=async_write_metrics,
        )

        if grad_scaler is None:
            from torch.cuda.amp import GradScaler

            grad_scaler = GradScaler()
        self.grad_scaler = grad_scaler
        self.precision = precision
        self.log_grad_scaler = log_grad_scaler

    def run_step(self):
        """
        Implement the AMP training logic.
        """
        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        start = time.perf_counter()
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        if self.zero_grad_before_forward:
            self.optimizer.zero_grad()
        # Only the forward pass and the loss reduction run under autocast.
        with autocast(dtype=self.precision):
            loss_dict = self.model(data)
            if isinstance(loss_dict, torch.Tensor):
                # A bare tensor is treated as the total loss and wrapped so the
                # metric writer always receives a mapping.
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                losses = sum(loss_dict.values())

        if not self.zero_grad_before_forward:
            self.optimizer.zero_grad()

        # Backward runs on the scaled loss.
        self.grad_scaler.scale(losses).backward()

        if self.log_grad_scaler:
            storage = get_event_storage()
            storage.put_scalar("[metric]grad_scaler", self.grad_scaler.get_scale())

        self.after_backward()

        if self.async_write_metrics:
            # Hand the write to the executor; the current iteration is passed
            # explicitly instead of being read from self.iter by the worker.
            self.concurrent_executor.submit(
                self._write_metrics, loss_dict, data_time, iter=self.iter
            )
        else:
            self._write_metrics(loss_dict, data_time)

        # The optimizer step goes through the scaler, which is then updated
        # for the next iteration.
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()

    def state_dict(self):
        """Return the :class:`SimpleTrainer` state extended with the grad scaler state."""
        ret = super().state_dict()
        ret["grad_scaler"] = self.grad_scaler.state_dict()
        return ret

    def load_state_dict(self, state_dict):
        """Restore the :class:`SimpleTrainer` state, then the grad scaler state."""
        super().load_state_dict(state_dict)
        self.grad_scaler.load_state_dict(state_dict["grad_scaler"])
| |
|