# NOTE(review): the three lines below are residue from a Hugging Face web
# page scrape and are not Python; kept as comments so the file parses.
# unknownuser6666's picture
# Upload folder using huggingface_hub
# 663494c verified
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import platform
import shutil
import time
import warnings
import copy
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.utils.data import DataLoader
import mmcv
from mmcv.runner.epoch_based_runner import EpochBasedRunner
from mmcv.runner.builder import RUNNERS
from mmcv.runner.checkpoint import save_checkpoint
from mmcv.runner.utils import get_host_info
@RUNNERS.register_module()
class EpochBasedRunnerAutoResume(EpochBasedRunner):
    """Epoch-based runner with auto-resume support.

    Trains the model epoch by epoch.  If the dataloader's sampler exposes a
    ``start_iter`` attribute (set when a previous job was interrupted
    mid-epoch), ``self._inner_iter`` is offset by it so the resumed epoch
    continues from where the last run stopped instead of restarting at
    sample 0.
    """

    def train(self, data_loader, **kwargs):
        """Run one training epoch over ``data_loader``.

        Args:
            data_loader (DataLoader): loader yielding this epoch's batches.
            **kwargs: forwarded unchanged to ``run_iter``.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        # A resumable sampler carries the number of iterations already
        # completed in this epoch by the previous job; plain samplers do
        # not have the attribute, so fall back to no offset.  (Was a bare
        # ``except:`` — narrowed so real errors are not silently hidden.)
        try:
            iter_offset = data_loader.sampler.start_iter
        except AttributeError:
            iter_offset = 0
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            # Data sanity check: a None batch (or None values nested inside
            # it) indicates a broken dataset/collate pipeline.
            if data_batch is None:
                print(f"[Data Check] data_batch is None at iteration {i}")
                del self.data_batch  # don't leave a stale reference behind
                continue
            self._report_none_values(data_batch, i)
            # Apply the resume offset; only the epoch being resumed is
            # affected (iter_offset is 0 otherwise).
            self._inner_iter = i + iter_offset
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            del self.data_batch
            self._iter += 1
            # Once the offset inner iter reaches the epoch length, reset it
            # so the next epoch does not skip samples, and end this epoch.
            if self._inner_iter + 1 == len(self.data_loader):
                self._inner_iter = 0
                break
        self.call_hook('after_train_epoch')
        self._epoch += 1

    @staticmethod
    def _report_none_values(data_batch, i):
        """Print a diagnostic for every None found one level deep in
        ``data_batch`` (a dict whose values may be lists/tuples/dicts).

        Args:
            data_batch (dict): the batch produced by the dataloader.
            i (int): current iteration index, included in the messages.
        """
        for key, value in data_batch.items():
            if value is None:
                print(f"[Data Check] data_batch['{key}'] is None at iteration {i}")
            elif isinstance(value, (list, tuple)):
                for j, item in enumerate(value):
                    if item is None:
                        print(f"[Data Check] data_batch['{key}'][{j}] is None at iteration {i}")
            elif isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    if sub_value is None:
                        print(f"[Data Check] data_batch['{key}']['{sub_key}'] is None at iteration {i}")