|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import bisect |
|
|
import logging |
|
|
import time |
|
|
from typing import Dict, List, Optional, Sequence, Tuple, Union |
|
|
|
|
|
import torch |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from mmengine.evaluator import Evaluator |
|
|
from mmengine.logging import print_log |
|
|
from mmengine.registry import LOOPS |
|
|
from .amp import autocast |
|
|
from .base_loop import BaseLoop |
|
|
from .utils import calc_dynamic_intervals |
|
|
import socket |
|
|
|
|
|
|
|
|
@LOOPS.register_module()
class EpochBasedTrainLoop(BaseLoop):
    """Loop for epoch-based training.

    Training proceeds epoch by epoch until ``max_epochs`` epochs have been
    completed or ``stop_training`` becomes truthy. After every epoch the
    runner's validation loop (when one exists) is launched if the epoch
    counter has reached ``val_begin`` and is a multiple of the current
    ``val_interval``.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        max_epochs (int): Total training epochs.
        val_begin (int): The epoch that begins validating.
            Defaults to 1.
        val_interval (int): Validation interval. Defaults to 1.
        dynamic_intervals (List[Tuple[int, int]], optional): The
            first element in the tuple is a milestone and the second
            element is an interval. The interval is used after the
            corresponding milestone. Defaults to None.
    """

    def __init__(
            self,
            runner,
            dataloader: Union[DataLoader, Dict],
            max_epochs: int,
            val_begin: int = 1,
            val_interval: int = 1,
            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
        super().__init__(runner, dataloader)

        # ``int()`` then compare: rejects values such as 2.5 while still
        # accepting integral numbers like 2.0.
        epochs = int(max_epochs)
        assert epochs == max_epochs, \
            f'`max_epochs` should be a integer number, but get {max_epochs}.'
        self._max_epochs = epochs
        self._max_iters = epochs * len(self.dataloader)
        self._epoch = 0
        self._iter = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
        # Polled by ``run`` before each epoch; presumably set externally
        # (e.g. by a hook) to end training early — not set in this class.
        self.stop_training = False

        dataset = self.dataloader.dataset
        if not hasattr(dataset, 'metainfo'):
            print_log(
                f'Dataset {dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in visualizer will be '
                'None.',
                logger='current',
                level=logging.WARNING)
        else:
            self.runner.visualizer.dataset_meta = dataset.metainfo

        milestones, intervals = calc_dynamic_intervals(
            self.val_interval, dynamic_intervals)
        self.dynamic_milestones = milestones
        self.dynamic_intervals = intervals

    @property
    def max_epochs(self):
        """int: Total epochs to train model."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Total iterations to train model."""
        return self._max_iters

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    def run(self) -> torch.nn.Module:
        """Launch training.

        Returns:
            torch.nn.Module: ``self.runner.model`` once training ends.
        """
        self.runner.call_hook('before_train')

        while not self.stop_training and self._epoch < self._max_epochs:
            self.run_epoch()

            # Refresh ``val_interval`` from the dynamic schedule before
            # deciding whether this epoch ends with a validation pass.
            self._decide_current_val_interval()
            if self.runner.val_loop is None:
                continue
            if (self._epoch >= self.val_begin
                    and self._epoch % self.val_interval == 0):
                self.runner.val_loop.run()

        self.runner.call_hook('after_train')
        return self.runner.model

    def run_epoch(self) -> None:
        """Iterate one epoch.

        Switches the model to training mode, feeds every batch of the
        dataloader to ``run_iter`` and finally advances the epoch counter.
        """
        self.runner.call_hook('before_train_epoch')
        # Validation may have run after the previous epoch and switched the
        # model to eval mode, so restore training mode here.
        self.runner.model.train()
        for batch_idx, batch in enumerate(self.dataloader):
            self.run_iter(batch_idx, batch)
        self.runner.call_hook('after_train_epoch')
        self._epoch += 1

    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Args:
            idx (int): Index of the batch within the current epoch.
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        hook_kwargs = dict(batch_idx=idx, data_batch=data_batch)
        self.runner.call_hook('before_train_iter', **hook_kwargs)
        outputs = self.runner.model.train_step(
            data_batch, optim_wrapper=self.runner.optim_wrapper)
        self.runner.call_hook(
            'after_train_iter', **hook_kwargs, outputs=outputs)
        self._iter += 1

    def _decide_current_val_interval(self) -> None:
        """Dynamically modify the ``val_interval``.

        Selects the interval paired with the last milestone that is
        ``<= self.epoch + 1``. Assumes ``dynamic_milestones`` is sorted
        ascending (required by ``bisect``) and index-aligned with
        ``dynamic_intervals`` — TODO confirm against
        ``calc_dynamic_intervals``.
        """
        position = bisect.bisect(self.dynamic_milestones, self.epoch + 1)
        self.val_interval = self.dynamic_intervals[position - 1]
|
|
|
|
|
|
|
|
class _InfiniteDataloaderIterator: |
|
|
"""An infinite dataloader iterator wrapper for IterBasedTrainLoop. |
|
|
|
|
|
It resets the dataloader to continue iterating when the iterator has |
|
|
iterated over all the data. However, this approach is not efficient, as the |
|
|
workers need to be restarted every time the dataloader is reset. It is |
|
|
recommended to use `mmengine.dataset.InfiniteSampler` to enable the |
|
|
dataloader to iterate infinitely. |
|
|
""" |
|
|
|
|
|
def __init__(self, dataloader: DataLoader) -> None: |
|
|
self._dataloader = dataloader |
|
|
self._iterator = iter(self._dataloader) |
|
|
self._epoch = 0 |
|
|
|
|
|
def __iter__(self): |
|
|
return self |
|
|
|
|
|
def __next__(self) -> Sequence[dict]: |
|
|
try: |
|
|
data = next(self._iterator) |
|
|
except StopIteration: |
|
|
print_log( |
|
|
'Reach the end of the dataloader, it will be ' |
|
|
'restarted and continue to iterate. It is ' |
|
|
'recommended to use ' |
|
|
'`mmengine.dataset.InfiniteSampler` to enable the ' |
|
|
'dataloader to iterate infinitely.', |
|
|
logger='current', |
|
|
level=logging.WARNING) |
|
|
self._epoch += 1 |
|
|
if hasattr(self._dataloader, 'sampler') and hasattr( |
|
|
self._dataloader.sampler, 'set_epoch'): |
|
|
|
|
|
|
|
|
self._dataloader.sampler.set_epoch(self._epoch) |
|
|
|
|
|
elif hasattr(self._dataloader, 'batch_sampler') and hasattr( |
|
|
self._dataloader.batch_sampler.sampler, 'set_epoch'): |
|
|
|
|
|
|
|
|
|
|
|
self._dataloader.batch_sampler.sampler.set_epoch(self._epoch) |
|
|
time.sleep(2) |
|
|
self._iterator = iter(self._dataloader) |
|
|
data = next(self._iterator) |
|
|
return data |
|
|
|
|
|
|
|
|
@LOOPS.register_module()
class IterBasedTrainLoop(BaseLoop):
    """Loop for iter-based training.

    Training runs iteration by iteration until ``max_iters`` iterations
    have been completed or ``stop_training`` becomes truthy. Batches are
    drawn from an :class:`_InfiniteDataloaderIterator`, which restarts the
    dataloader whenever it is exhausted.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        max_iters (int): Total training iterations.
        val_begin (int): The iteration that begins validating.
            Defaults to 1.
        val_interval (int): Validation interval. Defaults to 1000.
        dynamic_intervals (List[Tuple[int, int]], optional): The
            first element in the tuple is a milestone and the second
            element is an interval. The interval is used after the
            corresponding milestone. Defaults to None.
    """

    def __init__(
            self,
            runner,
            dataloader: Union[DataLoader, Dict],
            max_iters: int,
            val_begin: int = 1,
            val_interval: int = 1000,
            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
        super().__init__(runner, dataloader)
        # Reject non-integral values such as 2.5 while still accepting
        # integral numbers like 2.0.
        self._max_iters = int(max_iters)
        assert self._max_iters == max_iters, \
            f'`max_iters` should be a integer number, but get {max_iters}'
        # The whole iteration-based run is treated as a single epoch.
        self._max_epochs = 1
        self._epoch = 0
        self._iter = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
        # Polled by ``run`` before each iteration; presumably set externally
        # (e.g. by a hook) to end training early — not set in this class.
        self.stop_training = False
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in visualizer will be '
                'None.',
                logger='current',
                level=logging.WARNING)

        self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)

        self.dynamic_milestones, self.dynamic_intervals = \
            calc_dynamic_intervals(
                self.val_interval, dynamic_intervals)

    @property
    def max_epochs(self):
        """int: Total epochs to train model."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Total iterations to train model."""
        return self._max_iters

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    def run(self) -> torch.nn.Module:
        """Launch training.

        Returns:
            torch.nn.Module: ``self.runner.model`` once training ends.
        """
        self.runner.call_hook('before_train')
        # The whole iteration-based run is treated as one big epoch, so the
        # epoch hooks fire exactly once, around the iteration loop.
        self.runner.call_hook('before_train_epoch')
        while self._iter < self._max_iters and not self.stop_training:
            # Validation (``ValLoop.run``) switches the model to eval mode,
            # so training mode is restored before every iteration.
            self.runner.model.train()

            data_batch = next(self.dataloader_iterator)
            self.run_iter(data_batch)

            self._decide_current_val_interval()
            if (self.runner.val_loop is not None
                    and self._iter >= self.val_begin
                    and self._iter % self.val_interval == 0):
                self.runner.val_loop.run()

        self.runner.call_hook('after_train_epoch')
        self.runner.call_hook('after_train')
        return self.runner.model

    def run_iter(self, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        The ``batch_idx`` reported to hooks is the global iteration counter
        ``self._iter``, not an index within a dataloader pass.

        Args:
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
        outputs = self.runner.model.train_step(
            data_batch, optim_wrapper=self.runner.optim_wrapper)
        self.runner.call_hook(
            'after_train_iter',
            batch_idx=self._iter,
            data_batch=data_batch,
            outputs=outputs)
        self._iter += 1

    def _decide_current_val_interval(self) -> None:
        """Dynamically modify the ``val_interval``.

        Selects the interval paired with the last milestone that is
        ``<= self._iter + 1``. Assumes ``dynamic_milestones`` is sorted
        ascending (required by ``bisect``) and index-aligned with
        ``dynamic_intervals`` — TODO confirm against
        ``calc_dynamic_intervals``.
        """
        step = bisect.bisect(self.dynamic_milestones, (self._iter + 1))
        self.val_interval = self.dynamic_intervals[step - 1]
|
|
|
|
|
|
|
|
@LOOPS.register_module()
class ValLoop(BaseLoop):
    """Loop for validation.

    Runs the model in eval mode over the whole dataloader, feeds every
    batch's outputs to the evaluator and returns the computed metrics.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 validation. Defaults to
            False.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False) -> None:
        super().__init__(runner, dataloader)

        if not isinstance(evaluator, (dict, list)):
            # An already-built evaluator must be a real ``Evaluator``.
            assert isinstance(evaluator, Evaluator), (
                'evaluator must be one of dict, list or Evaluator instance, '
                f'but got {type(evaluator)}.')
            self.evaluator = evaluator
        else:
            # Config-style input: let the runner build the evaluator.
            self.evaluator = runner.build_evaluator(evaluator)

        dataset = self.dataloader.dataset
        if not hasattr(dataset, 'metainfo'):
            print_log(
                f'Dataset {dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in evaluator, metric and '
                'visualizer will be None.',
                logger='current',
                level=logging.WARNING)
        else:
            # ``metainfo`` is read separately for each consumer, so that if
            # it returns a fresh copy per access, they do not share one.
            self.evaluator.dataset_meta = dataset.metainfo
            self.runner.visualizer.dataset_meta = dataset.metainfo
        self.fp16 = fp16

    def run(self) -> dict:
        """Launch validation.

        Returns:
            dict: Metrics produced by ``self.evaluator.evaluate``.
        """
        for hook_name in ('before_val', 'before_val_epoch'):
            self.runner.call_hook(hook_name)
        self.runner.model.eval()
        for batch_idx, batch in enumerate(self.dataloader):
            self.run_iter(batch_idx, batch)

        # Metrics are computed once all batches have been processed.
        num_samples = len(self.dataloader.dataset)
        metrics = self.evaluator.evaluate(num_samples)
        self.runner.call_hook('after_val_epoch', metrics=metrics)
        self.runner.call_hook('after_val')
        return metrics

    @torch.no_grad()
    def run_iter(self, idx, data_batch: Sequence[dict]):
        """Iterate one mini-batch without tracking gradients.

        Args:
            idx (int): Index of the batch within the dataloader.
            data_batch (Sequence[dict]): Batch of data
                from dataloader.
        """
        self.runner.call_hook(
            'before_val_iter', batch_idx=idx, data_batch=data_batch)
        # Only the forward pass runs under (optional) autocast.
        with autocast(enabled=self.fp16):
            predictions = self.runner.model.val_step(data_batch)
        self.evaluator.process(data_samples=predictions, data_batch=data_batch)
        self.runner.call_hook(
            'after_val_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=predictions)
|
|
|
|
|
|
|
|
@LOOPS.register_module()
class TestLoop(BaseLoop):
    """Loop for test.

    Runs the model in eval mode over the whole dataloader, feeds every
    batch's outputs to the evaluator and returns the computed metrics.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 testing. Defaults to
            False.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False) -> None:
        super().__init__(runner, dataloader)

        if isinstance(evaluator, (dict, list)):
            # Config-style input: let the runner build the evaluator.
            self.evaluator = runner.build_evaluator(evaluator)
        else:
            # An already-built evaluator is used as-is; no type check is
            # performed here.
            self.evaluator = evaluator
        if hasattr(self.dataloader.dataset, 'metainfo'):
            # ``metainfo`` is read separately for each consumer, so that if
            # it returns a fresh copy per access, they do not share one.
            self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in evaluator, metric and '
                'visualizer will be None.',
                logger='current',
                level=logging.WARNING)
        self.fp16 = fp16

    def run(self) -> dict:
        """Launch test.

        Returns:
            dict: Metrics produced by ``self.evaluator.evaluate``.
        """
        self.runner.call_hook('before_test')
        self.runner.call_hook('before_test_epoch')
        self.runner.model.eval()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        # Metrics are computed once all batches have been processed.
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
        self.runner.call_hook('after_test_epoch', metrics=metrics)
        self.runner.call_hook('after_test')
        return metrics

    @torch.no_grad()
    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch without tracking gradients.

        Args:
            idx (int): Index of the batch within the dataloader.
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_test_iter', batch_idx=idx, data_batch=data_batch)
        # Only the forward pass runs under (optional) autocast.
        with autocast(enabled=self.fp16):
            outputs = self.runner.model.test_step(data_batch)
        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            'after_test_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
|
|
|