# Copyright (c) OpenMMLab. All rights reserved.
import time

from mmcv.runner.builder import RUNNERS
from mmcv.runner.epoch_based_runner import EpochBasedRunner


@RUNNERS.register_module()
class EpochBasedRunnerAutoResume(EpochBasedRunner):
    """Epoch-based Runner.

    This runner train models epoch by epoch.
    """

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        # If the sampler exposes `start_iter` (i.e. some iterations of this
        # epoch were already run before the job was interrupted), use it as
        # an offset so the already-consumed samples are skipped on resume.
        iter_offset = getattr(self.data_loader.sampler, 'start_iter', 0)
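
        # Example (hypothetical numbers): if a 5000-iteration epoch was
        # interrupted after 3500 iterations, a resume-aware sampler would
        # report start_iter == 3500, skip the consumed samples itself, and
        # the loop below would only run the remaining 1500 batches.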

        # iterate the dataloader
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch

            # Sanity-check the batch before running the iteration
            if data_batch is None:
                self.logger.warning(
                    f'[Data Check] data_batch is None at iteration {i}')
                continue

            # Check every top-level key (and one level of nesting) for None
            for key, value in data_batch.items():
                if value is None:
                    self.logger.warning(
                        f"[Data Check] data_batch['{key}'] is None "
                        f'at iteration {i}')
                elif isinstance(value, (list, tuple)):
                    for j, item in enumerate(value):
                        if item is None:
                            self.logger.warning(
                                f"[Data Check] data_batch['{key}'][{j}] "
                                f'is None at iteration {i}')
                elif isinstance(value, dict):
                    for sub_key, sub_value in value.items():
                        if sub_value is None:
                            self.logger.warning(
                                f"[Data Check] data_batch['{key}']"
                                f"['{sub_key}'] is None at iteration {i}")

            # A naive alternative is to iterate the dataloader and
            # `continue` past the first `self._inner_iter` batches, but that
            # still pays the full data-loading cost for every skipped batch
            # and is very slow; the sampler-side offset above avoids it.

            # Add the offset so `_inner_iter` reflects the true position
            # within the epoch; the offset is non-zero only for the
            # particular epoch being resumed.
            self._inner_iter = i + iter_offset

            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            del self.data_batch
            self._iter += 1

            # Once a full epoch's worth of samples has been seen, reset the
            # inner iteration counter so the next epoch does not skip any
            # samples, then stop iterating.
            if self._inner_iter + 1 == len(self.data_loader):
                self._inner_iter = 0
                break

        self.call_hook('after_train_epoch')
        self._epoch += 1
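

# ---------------------------------------------------------------------------
# Hedged sketch: a sampler providing the `start_iter` protocol assumed by
# `train()` above. mmcv does not ship such a sampler; the class name
# `ResumableDistributedSampler` and the `set_start_iter` helper are
# illustrative assumptions, not an upstream API.
from torch.utils.data import DistributedSampler


class ResumableDistributedSampler(DistributedSampler):
    """DistributedSampler that can skip batches already consumed.

    `start_iter` counts dataloader iterations (batches), so the sampler
    needs the per-GPU batch size to translate it into a number of indices
    to drop. `__len__` is deliberately left unchanged: the runner's
    `len(self.data_loader)` must keep reporting the full epoch length for
    the `_inner_iter` bookkeeping in `train()` to work.
    """

    def __init__(self, dataset, samples_per_gpu=1, **kwargs):
        super().__init__(dataset, **kwargs)
        self.samples_per_gpu = samples_per_gpu
        self.start_iter = 0

    def set_start_iter(self, start_iter):
        # Typically called by resume logic with
        # `runner.iter % len(data_loader)` for the interrupted epoch.
        self.start_iter = start_iter

    def __iter__(self):
        # With shuffling enabled, `set_epoch()` must be called with the
        # same epoch number on resume so the permutation matches the one
        # used by the interrupted job.
        indices = list(super().__iter__())
        # Drop the indices consumed before the interruption; `train()`
        # resets its offset once the resumed epoch completes.
        return iter(indices[self.start_iter * self.samples_per_gpu:])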