|
|
|
|
|
|
|
|
import concurrent.futures |
|
|
import logging |
|
|
import time |
|
|
import weakref |
|
|
from collections import defaultdict
from typing import List, Mapping, Optional
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from torch.nn.parallel import DataParallel, DistributedDataParallel |
|
|
|
|
|
import detectron2.utils.comm as comm |
|
|
from detectron2.engine.train_loop import HookBase, TrainerBase |
|
|
from detectron2.utils.events import EventStorage, get_event_storage |
|
|
from detectron2.utils.logger import _log_api_usage |
|
|
|
|
|
__all__ = ["SimpleTrainer", "AMPTrainer"] |
|
|
|
|
|
|
|
|
class SimpleTrainer(TrainerBase): |
|
|
""" |
|
|
A simple trainer for the most common type of task: |
|
|
single-cost single-optimizer single-data-source iterative optimization, |
|
|
optionally using data-parallelism. |
|
|
    It assumes that in every step, you:
|
|
|
|
|
    1. Compute the loss with data from the data_loader.
|
|
2. Compute the gradients with the above loss. |
|
|
3. Update the model with the optimizer. |
|
|
|
|
|
All other tasks during training (checkpointing, logging, evaluation, LR schedule) |
|
|
are maintained by hooks, which can be registered by :meth:`TrainerBase.register_hooks`. |
|
|
|
|
|
If you want to do anything fancier than this, |
|
|
either subclass TrainerBase and implement your own `run_step`, |
|
|
or write your own training loop. |
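
    A minimal usage sketch (assuming ``model``, ``data_loader`` and ``optimizer``
    are already built; the hooks shown are optional)::

        trainer = SimpleTrainer(model, data_loader, optimizer)
        trainer.register_hooks([...])  # e.g. checkpointing / LR-schedule hooks
        trainer.train(start_iter=0, max_iter=90000)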
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model, |
|
|
data_loader, |
|
|
optimizer, |
|
|
gather_metric_period=1, |
|
|
zero_grad_before_forward=False, |
|
|
async_write_metrics=False, |
|
|
): |
|
|
""" |
|
|
Args: |
|
|
            model: a torch Module. Takes data from the data_loader and returns a
|
|
dict of losses. |
|
|
data_loader: an iterable. Contains data to be used to call model. |
|
|
optimizer: a torch optimizer. |
|
|
gather_metric_period: an int. Every gather_metric_period iterations |
|
|
the metrics are gathered from all the ranks to rank 0 and logged. |
|
|
zero_grad_before_forward: whether to zero the gradients before the forward. |
|
|
            async_write_metrics: bool. If True, write metrics asynchronously to improve
                training speed.
|
|
""" |
|
|
super().__init__() |
|
|
|
|
|
""" |
|
|
We set the model to training mode in the trainer. |
|
|
However it's valid to train a model that's in eval mode. |
|
|
If you want your model (or a submodule of it) to behave |
|
|
like evaluation during training, you can overwrite its train() method. |
|
|
""" |
|
|
model.train() |
|
|
|
|
|
self.model = model |
|
|
self.data_loader = data_loader |
|
|
|
|
|
self._data_loader_iter_obj = None |
|
|
self.optimizer = optimizer |
|
|
self.gather_metric_period = gather_metric_period |
|
|
self.zero_grad_before_forward = zero_grad_before_forward |
|
|
        self.async_write_metrics = async_write_metrics

        # Per-dataset statistics accumulated in run_step() and logged as
        # "count_image/<name>" and "count_object/<name>" scalars.
        # NOTE: `dataset_names` maps the optional "dataset_id" field of each input
        # dict to a readable name; the single-entry default below is only a
        # placeholder and is expected to be overwritten by callers that train on
        # multiple datasets.
        self.dataset_names = {0: "default"}
        self.dataset_image_counts = defaultdict(int)
        self.dataset_object_counts = defaultdict(int)
|
|
|
|
|
|
|
|
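        # A single-worker thread pool used to run non-critical logic (metric
        # writing) off the main training thread when async_write_metrics is True;
        # a single worker guarantees tasks execute in submission order.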
self.concurrent_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) |
|
|
|
|
|
def run_step(self): |
|
|
""" |
|
|
Implement the standard training logic described above. |
|
|
""" |
|
|
assert self.model.training, "[SimpleTrainer] model was changed to eval mode!" |
|
|
start = time.perf_counter() |
|
|
""" |
|
|
If you want to do something with the data, you can wrap the dataloader. |
|
|
""" |
|
|
data = next(self._data_loader_iter) |
|
|
data_time = time.perf_counter() - start |
|
|
|
|
|
|
|
|
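        # update the per-dataset image/object counters for this batch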
for d in data: |
|
|
self.dataset_image_counts[self.dataset_names[d.get("dataset_id", 0)]] += 1 |
|
|
self.dataset_object_counts[self.dataset_names[d.get("dataset_id", 0)]] += len( |
|
|
d.get("instances", []) |
|
|
) |
|
|
        dataset_image_counts = {
            f"count_image/{k}": v for k, v in self.dataset_image_counts.items()
        }
|
|
dataset_object_counts = { |
|
|
f"count_object/{k}": v for k, v in self.dataset_object_counts.items() |
|
|
} |
|
|
if self.async_write_metrics: |
|
|
|
|
|
self.concurrent_executor.submit( |
|
|
self._write_metrics_common, dataset_image_counts, iter=self.iter |
|
|
) |
|
|
self.concurrent_executor.submit( |
|
|
self._write_metrics_common, dataset_object_counts, iter=self.iter |
|
|
) |
|
|
else: |
|
|
self._write_metrics_common(dataset_image_counts) |
|
|
self._write_metrics_common(dataset_object_counts) |
|
|
|
|
|
|
|
|
if self.zero_grad_before_forward: |
|
|
""" |
|
|
If you need to accumulate gradients or do something similar, you can |
|
|
wrap the optimizer with your custom `zero_grad()` method. |
|
|
""" |
|
|
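            # For instance, a hypothetical wrapper (a sketch, not part of this
            # trainer) that accumulates gradients over `accum_steps` iterations by
            # clearing them only at the start of each accumulation window and
            # stepping only at its end:
            #
            #     class GradAccumOptimizer:
            #         def __init__(self, optimizer, accum_steps):
            #             self.optimizer = optimizer
            #             self.accum_steps = accum_steps
            #             self._iter = 0
            #
            #         def zero_grad(self, *args, **kwargs):
            #             if self._iter % self.accum_steps == 0:
            #                 self.optimizer.zero_grad(*args, **kwargs)
            #
            #         def step(self, *args, **kwargs):
            #             self._iter += 1
            #             if self._iter % self.accum_steps == 0:
            #                 self.optimizer.step(*args, **kwargs)
            #
            # (loss averaging over the window and delegation of state_dict() /
            # load_state_dict() to the inner optimizer are omitted for brevity)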
self.optimizer.zero_grad() |
|
|
|
|
|
""" |
|
|
If you want to do something with the losses, you can wrap the model. |
|
|
""" |
|
|
loss_dict = self.model(data) |
|
|
if isinstance(loss_dict, torch.Tensor): |
|
|
losses = loss_dict |
|
|
loss_dict = {"total_loss": loss_dict} |
|
|
else: |
|
|
losses = sum(loss_dict.values()) |
|
|
if not self.zero_grad_before_forward: |
|
|
""" |
|
|
If you need to accumulate gradients or do something similar, you can |
|
|
wrap the optimizer with your custom `zero_grad()` method. |
|
|
""" |
|
|
self.optimizer.zero_grad() |
|
|
losses.backward() |
|
|
|
|
|
self.after_backward() |
|
|
|
|
|
if self.async_write_metrics: |
|
|
|
|
|
self.concurrent_executor.submit( |
|
|
self._write_metrics, loss_dict, data_time, iter=self.iter |
|
|
) |
|
|
else: |
|
|
self._write_metrics(loss_dict, data_time) |
|
|
|
|
|
""" |
|
|
If you need gradient clipping/scaling or other processing, you can |
|
|
wrap the optimizer with your custom `step()` method. But it is |
|
|
suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4 |
|
|
""" |
|
|
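        # For example, a hypothetical wrapper (sketch only) whose step() clips
        # gradients before the real optimizer step:
        #
        #     class ClippingOptimizer:
        #         def __init__(self, optimizer, parameters, max_norm=1.0):
        #             self.optimizer = optimizer
        #             self.parameters = list(parameters)
        #             self.max_norm = max_norm
        #
        #         def zero_grad(self, *args, **kwargs):
        #             self.optimizer.zero_grad(*args, **kwargs)
        #
        #         def step(self, *args, **kwargs):
        #             torch.nn.utils.clip_grad_norm_(self.parameters, self.max_norm)
        #             self.optimizer.step(*args, **kwargs)
        #
        # (state_dict delegation omitted for brevity)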
self.optimizer.step() |
|
|
|
|
|
@property |
|
|
def _data_loader_iter(self): |
|
|
|
|
|
if self._data_loader_iter_obj is None: |
|
|
self._data_loader_iter_obj = iter(self.data_loader) |
|
|
return self._data_loader_iter_obj |
|
|
|
|
|
def reset_data_loader(self, data_loader_builder): |
|
|
""" |
|
|
Delete and replace the current data loader with a new one, which will be created |
|
|
by calling `data_loader_builder` (without argument). |
|
|
""" |
|
|
del self.data_loader |
|
|
data_loader = data_loader_builder() |
|
|
self.data_loader = data_loader |
|
|
self._data_loader_iter_obj = None |
|
|
|
|
|
def _write_metrics( |
|
|
self, |
|
|
loss_dict: Mapping[str, torch.Tensor], |
|
|
data_time: float, |
|
|
prefix: str = "", |
|
|
iter: Optional[int] = None, |
|
|
) -> None: |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
iter = self.iter if iter is None else iter |
|
|
if (iter + 1) % self.gather_metric_period == 0: |
|
|
try: |
|
|
SimpleTrainer.write_metrics(loss_dict, data_time, iter, prefix) |
|
|
except Exception: |
|
|
logger.exception("Exception in writing metrics: ") |
|
|
raise |
|
|
|
|
|
@staticmethod |
|
|
def write_metrics( |
|
|
loss_dict: Mapping[str, torch.Tensor], |
|
|
data_time: float, |
|
|
cur_iter: int, |
|
|
prefix: str = "", |
|
|
) -> None: |
|
|
""" |
|
|
Args: |
|
|
loss_dict (dict): dict of scalar losses |
|
|
data_time (float): time taken by the dataloader iteration |
|
|
            cur_iter (int): the iteration at which the metrics are logged
            prefix (str): prefix for logging keys
|
|
""" |
|
|
metrics_dict = {k: v.detach().cpu().item() for k, v in loss_dict.items()} |
|
|
metrics_dict["data_time"] = data_time |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
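        # gather metrics from all workers for logging; this assumes DDP-style
        # training (one process per GPU), where every worker reports its own metrics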
all_metrics_dict = comm.gather(metrics_dict) |
|
|
|
|
|
if comm.is_main_process(): |
|
|
storage = get_event_storage() |
|
|
|
|
|
|
|
|
|
|
|
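            # data_time among workers can have high variance; the latency it adds to
            # a step is determined by the slowest worker, so report the maximum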
data_time = np.max([x.pop("data_time") for x in all_metrics_dict]) |
|
|
storage.put_scalar("data_time", data_time, cur_iter=cur_iter) |
|
|
|
|
|
|
|
|
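            # average the remaining metrics over workers; build the union of keys
            # first, since workers may report different sets of metrics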
all_metrics_key = [] |
|
|
for metrics_dict in all_metrics_dict: |
|
|
for key in metrics_dict.keys(): |
|
|
if key not in all_metrics_key: |
|
|
all_metrics_key.append(key) |
|
|
metrics_dict = { |
|
|
k: np.mean([x[k] for x in all_metrics_dict if k in x]) for k in all_metrics_key |
|
|
} |
|
|
total_losses_reduced = sum(metrics_dict.values()) |
|
|
if not np.isfinite(total_losses_reduced): |
|
|
raise FloatingPointError( |
|
|
f"Loss became infinite or NaN at iteration={cur_iter}!\n" |
|
|
f"loss_dict = {metrics_dict}" |
|
|
) |
|
|
|
|
|
storage.put_scalar( |
|
|
"{}total_loss".format(prefix), total_losses_reduced, cur_iter=cur_iter |
|
|
) |
|
|
if len(metrics_dict) > 1: |
|
|
storage.put_scalars(cur_iter=cur_iter, **metrics_dict) |
|
|
|
|
|
def state_dict(self): |
|
|
ret = super().state_dict() |
|
|
ret["optimizer"] = self.optimizer.state_dict() |
|
|
return ret |
|
|
|
|
|
def load_state_dict(self, state_dict): |
|
|
super().load_state_dict(state_dict) |
|
|
self.optimizer.load_state_dict(state_dict["optimizer"]) |
|
|
|
|
|
def after_train(self): |
|
|
super().after_train() |
|
|
self.concurrent_executor.shutdown(wait=True) |
|
|
|
|
|
def _write_metrics_common( |
|
|
self, |
|
|
metrics_dict: Mapping[str, torch.Tensor], |
|
|
prefix: str = "", |
|
|
iter: Optional[int] = None, |
|
|
) -> None: |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
iter = self.iter if iter is None else iter |
|
|
if (iter + 1) % self.gather_metric_period == 0: |
|
|
try: |
|
|
SimpleTrainer.write_metrics_common(metrics_dict, iter, prefix) |
|
|
except Exception: |
|
|
logger.exception("Exception in writing metrics: ") |
|
|
raise |
|
|
|
|
|
@staticmethod |
|
|
def write_metrics_common( |
|
|
metrics_dict: Mapping[str, torch.Tensor], |
|
|
cur_iter: int, |
|
|
prefix: str = "", |
|
|
) -> None: |
|
|
""" |
|
|
Args: |
|
|
            metrics_dict (dict): dict of scalar metrics (e.g. per-dataset counts)
            cur_iter (int): the iteration at which the metrics are logged
|
|
prefix (str): prefix for logging keys |
|
|
""" |
|
|
        # values may be tensors (losses) or plain Python numbers (e.g. counts)
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
|
|
all_metrics_dict = comm.gather(metrics_dict) |
|
|
if comm.is_main_process(): |
|
|
storage = get_event_storage() |
|
|
|
|
|
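            # counts are additive across workers, so sum them rather than average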
            # use the union of keys, since different ranks may have seen different datasets
            all_keys = {k for per_rank in all_metrics_dict for k in per_rank}
            metrics_dict = {
                k: np.sum([x[k] for x in all_metrics_dict if k in x]) for k in all_keys
            }
|
|
|
|
|
            # unlike write_metrics(), there is no separate total scalar here, so log
            # even when only a single metric is present
            if metrics_dict:
|
|
storage.put_scalars(cur_iter=cur_iter, **metrics_dict) |
|
|
|
|
|
|
|
|
class AMPTrainer(SimpleTrainer): |
|
|
""" |
|
|
Like :class:`SimpleTrainer`, but uses PyTorch's native automatic mixed precision |
|
|
in the training loop. |
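
    A minimal usage sketch (same arguments as :class:`SimpleTrainer`;
    ``torch.bfloat16`` is shown only as an example of overriding the default
    ``torch.float16`` precision)::

        trainer = AMPTrainer(model, data_loader, optimizer, precision=torch.bfloat16)
        trainer.train(start_iter=0, max_iter=90000)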
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model, |
|
|
data_loader, |
|
|
optimizer, |
|
|
gather_metric_period=1, |
|
|
zero_grad_before_forward=False, |
|
|
grad_scaler=None, |
|
|
precision: torch.dtype = torch.float16, |
|
|
log_grad_scaler: bool = False, |
|
|
async_write_metrics=False, |
|
|
): |
|
|
""" |
|
|
Args: |
|
|
model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward, |
|
|
async_write_metrics: same as in :class:`SimpleTrainer`. |
|
|
grad_scaler: torch GradScaler to automatically scale gradients. |
|
|
            precision: torch.dtype as the target precision to cast to in computations
            log_grad_scaler: bool. If True, log the grad scaler's scale to the event storage
|
|
""" |
|
|
unsupported = "AMPTrainer does not support single-process multi-device training!" |
|
|
if isinstance(model, DistributedDataParallel): |
|
|
assert not (model.device_ids and len(model.device_ids) > 1), unsupported |
|
|
assert not isinstance(model, DataParallel), unsupported |
|
|
|
|
|
        super().__init__(
            model,
            data_loader,
            optimizer,
            gather_metric_period,
            zero_grad_before_forward,
            async_write_metrics=async_write_metrics,
        )
|
|
|
|
|
if grad_scaler is None: |
|
|
from torch.cuda.amp import GradScaler |
|
|
|
|
|
grad_scaler = GradScaler() |
|
|
self.grad_scaler = grad_scaler |
|
|
self.precision = precision |
|
|
self.log_grad_scaler = log_grad_scaler |
|
|
|
|
|
def run_step(self): |
|
|
""" |
|
|
Implement the AMP training logic. |
|
|
""" |
|
|
assert self.model.training, "[AMPTrainer] model was changed to eval mode!" |
|
|
assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!" |
|
|
from torch.cuda.amp import autocast |
|
|
|
|
|
start = time.perf_counter() |
|
|
data = next(self._data_loader_iter) |
|
|
data_time = time.perf_counter() - start |
|
|
|
|
|
|
|
|
for d in data: |
|
|
self.dataset_image_counts[self.dataset_names[d.get("dataset_id", 0)]] += 1 |
|
|
self.dataset_object_counts[self.dataset_names[d.get("dataset_id", 0)]] += len( |
|
|
d.get("instances", []) |
|
|
) |
|
|
dataset_image_counts = { |
|
|
f"count_image/{k}": v for k, v in self.dataset_image_counts.items() |
|
|
} |
|
|
dataset_object_counts = { |
|
|
f"count_object/{k}": v for k, v in self.dataset_object_counts.items() |
|
|
} |
|
|
if self.async_write_metrics: |
|
|
|
|
|
self.concurrent_executor.submit( |
|
|
self._write_metrics_common, dataset_image_counts, iter=self.iter |
|
|
) |
|
|
self.concurrent_executor.submit( |
|
|
self._write_metrics_common, dataset_object_counts, iter=self.iter |
|
|
) |
|
|
else: |
|
|
self._write_metrics_common(dataset_image_counts) |
|
|
self._write_metrics_common(dataset_object_counts) |
|
|
|
|
|
|
|
|
if self.zero_grad_before_forward: |
|
|
self.optimizer.zero_grad() |
|
|
with autocast(dtype=self.precision): |
|
|
loss_dict = self.model(data) |
|
|
if isinstance(loss_dict, torch.Tensor): |
|
|
losses = loss_dict |
|
|
loss_dict = {"total_loss": loss_dict} |
|
|
else: |
|
|
losses = sum(loss_dict.values()) |
|
|
|
|
|
if not self.zero_grad_before_forward: |
|
|
self.optimizer.zero_grad() |
|
|
|
|
|
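        # scale the loss before backward() so that small gradients stay representable
        # in reduced precision; gradients are unscaled inside grad_scaler.step()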
self.grad_scaler.scale(losses).backward() |
|
|
|
|
|
if self.log_grad_scaler: |
|
|
storage = get_event_storage() |
|
|
storage.put_scalar("[metric] grad_scaler", self.grad_scaler.get_scale()) |
|
|
|
|
|
self.after_backward() |
|
|
|
|
|
if self.async_write_metrics: |
|
|
|
|
|
self.concurrent_executor.submit( |
|
|
self._write_metrics, loss_dict, data_time, iter=self.iter |
|
|
) |
|
|
else: |
|
|
self._write_metrics(loss_dict, data_time) |
|
|
|
|
|
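        # grad_scaler.step() skips the optimizer step if inf/NaN gradients were
        # produced; update() then adjusts the scale for the next iteration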
self.grad_scaler.step(self.optimizer) |
|
|
self.grad_scaler.update() |
|
|
|
|
|
def state_dict(self): |
|
|
ret = super().state_dict() |
|
|
ret["grad_scaler"] = self.grad_scaler.state_dict() |
|
|
return ret |
|
|
|
|
|
def load_state_dict(self, state_dict): |
|
|
super().load_state_dict(state_dict) |
|
|
self.grad_scaler.load_state_dict(state_dict["grad_scaler"]) |
|
|
|