|
|
|
|
|
import os.path as osp |
|
|
import platform |
|
|
import shutil |
|
|
import time |
|
|
import warnings |
|
|
import copy |
|
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
import mmcv |
|
|
from mmcv.runner.epoch_based_runner import EpochBasedRunner |
|
|
from mmcv.runner.builder import RUNNERS |
|
|
from mmcv.runner.checkpoint import save_checkpoint |
|
|
from mmcv.runner.utils import get_host_info |
|
|
|
|
|
|
|
|
@RUNNERS.register_module()
class EpochBasedRunnerAutoResume(EpochBasedRunner):
    """Epoch-based runner with mid-epoch auto-resume support.

    Trains models epoch by epoch like the stock ``EpochBasedRunner``,
    with two additions:

    * If the dataloader's sampler exposes a ``start_iter`` attribute,
      the inner-iteration counter resumes from that offset instead of
      restarting the epoch from zero.
    * Every batch is scanned one nesting level deep for ``None``
      entries, which are reported to stdout — useful for diagnosing
      data-pipeline corruption after a resume.
    """

    @staticmethod
    def _report_none_entries(data_batch, i):
        """Print a diagnostic line for each ``None`` found in ``data_batch``.

        Scans top-level values, elements of list/tuple values, and
        values of dict values (one nesting level only).

        Args:
            data_batch (dict): one batch as produced by the dataloader.
            i (int): current enumeration index, included in the message.
        """
        for key, value in data_batch.items():
            if value is None:
                print(f"[Data Check] data_batch['{key}'] is None at iteration {i}")
            elif isinstance(value, (list, tuple)):
                for j, item in enumerate(value):
                    if item is None:
                        print(f"[Data Check] data_batch['{key}'][{j}] is None at iteration {i}")
            elif isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    if sub_value is None:
                        print(f"[Data Check] data_batch['{key}']['{sub_key}'] is None at iteration {i}")

    def train(self, data_loader, **kwargs):
        """Run one training epoch, resuming mid-epoch when possible.

        Args:
            data_loader (DataLoader): dataloader for this epoch. If its
                sampler has a ``start_iter`` attribute, iteration
                counting resumes from that offset.
            **kwargs: forwarded unchanged to ``self.run_iter``.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        # NOTE(review): upstream EpochBasedRunner sleeps here, presumably to
        # avoid a possible deadlock during the epoch transition — kept as-is.
        time.sleep(2)

        # Mid-epoch resume: a resumable sampler exposes the iteration it
        # should restart from; plain samplers fall back to offset 0.
        # (Was a bare ``except:`` — narrowed so KeyboardInterrupt and
        # real bugs are no longer silently swallowed.)
        try:
            iter_offset = data_loader.sampler.start_iter
        except AttributeError:
            iter_offset = 0

        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch

            # Skip (and report) batches the pipeline failed to produce.
            if data_batch is None:
                print(f"[Data Check] data_batch is None at iteration {i}")
                continue

            self._report_none_entries(data_batch, i)

            # Account for the resume offset so hooks (checkpointing,
            # LR schedules) see the true position within the epoch.
            self._inner_iter = i + iter_offset

            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            del self.data_batch
            self._iter += 1

            # When the resume offset pushes us to the logical end of the
            # epoch, reset the inner counter and stop this epoch early.
            if self._inner_iter + 1 == len(self.data_loader):
                self._inner_iter = 0
                break

        self.call_hook('after_train_epoch')
        self._epoch += 1
|
|
|