# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import platform
import shutil
import time
import warnings
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch.utils.data import DataLoader

import mmcv
from mmcv.runner.builder import RUNNERS
from mmcv.runner.checkpoint import save_checkpoint
from mmcv.runner.epoch_based_runner import EpochBasedRunner
from mmcv.runner.utils import get_host_info


@RUNNERS.register_module()
class EpochBasedRunnerAutoResume(EpochBasedRunner):
    """Epoch-based runner with mid-epoch auto-resume support.

    Trains models epoch by epoch, like :class:`EpochBasedRunner`, but if
    the dataloader's sampler exposes a ``start_iter`` attribute (set by an
    auto-resume mechanism — TODO confirm against the sampler implementation),
    ``self._inner_iter`` is offset by that value so that progress
    bookkeeping continues from where the interrupted epoch left off
    instead of restarting at zero.
    """

    @staticmethod
    def _report_none_entries(data_batch, iteration):
        """Print a diagnostic for every ``None`` found in ``data_batch``.

        Checks the top-level values and one level inside list/tuple/dict
        values. Purely diagnostic — it only prints, never raises or drops
        the batch.

        Args:
            data_batch (dict): Batch produced by the dataloader.
            iteration (int): Enumeration index, used in the message only.
        """
        for key, value in data_batch.items():
            if value is None:
                print(f"[Data Check] data_batch['{key}'] is None at iteration {iteration}")
            elif isinstance(value, (list, tuple)):
                for j, item in enumerate(value):
                    if item is None:
                        print(f"[Data Check] data_batch['{key}'][{j}] is None at iteration {iteration}")
            elif isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    if sub_value is None:
                        print(f"[Data Check] data_batch['{key}']['{sub_key}'] is None at iteration {iteration}")

    def train(self, data_loader, **kwargs):
        """Run one training epoch over ``data_loader``.

        Args:
            data_loader (DataLoader): Loader yielding training batches.
                If its sampler has a ``start_iter`` attribute, that value
                offsets ``self._inner_iter`` for auto-resume.
            **kwargs: Extra keyword arguments forwarded to ``run_iter``.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        # If the sampler records how many iterations were already done in
        # the interrupted epoch, apply that as an offset rather than
        # re-iterating (the commented-out skip-by-iterating approach in
        # the original was too slow). Only a missing attribute means
        # "no resume info" — anything else should surface, so the former
        # bare `except:` is narrowed to AttributeError.
        try:
            iter_offset = data_loader.sampler.start_iter
        except AttributeError:
            iter_offset = 0

        for i, data_batch in enumerate(self.data_loader):
            # Defensive data check: flag and skip malformed batches before
            # exposing them via self.data_batch (the original assigned the
            # attribute first, leaving a stale None behind on `continue`).
            if data_batch is None:
                print(f"[Data Check] data_batch is None at iteration {i}")
                continue
            self._report_none_entries(data_batch, i)

            self.data_batch = data_batch
            # _inner_iter reflects position within the full logical epoch,
            # including iterations completed before the restart.
            self._inner_iter = i + iter_offset
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            del self.data_batch
            self._iter += 1
            # Once the logical end of the (possibly resumed) epoch is
            # reached, reset _inner_iter so the next epoch does not skip
            # samples, and stop consuming the loader.
            if self._inner_iter + 1 == len(self.data_loader):
                self._inner_iter = 0
                break

        self.call_hook('after_train_epoch')
        self._epoch += 1