File size: 10,572 Bytes
33569f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
# python imports
import argparse
import os
import time
import datetime
import yaml
import json
from pprint import pprint

# torch imports
import torch

import torch.nn as nn
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange
# for visualization
# from torch.utils.tensorboard import SummaryWriter

# our code
from libs.core import load_config
from libs.datasets import make_dataset, make_data_loader
from libs.modeling import make_meta_arch
from libs.utils import (train_one_epoch, valid_one_epoch, ANETdetection,
                        save_checkpoint, make_optimizer, make_scheduler,
                        fix_random_seed, ModelEma, display_python_performance, get_average_performance, merge_ResultSaveObj)
import itertools
import collections
from IPython import embed

def load_json(filename):
    """Read *filename* as UTF-8 text, parse it as JSON, and return the decoded object."""
    with open(filename, 'r', encoding='utf8') as handle:
        decoded = json.load(handle)
    return decoded

from terminaltables import AsciiTable

################################################################################
def _split_domains(train_split_list, test_split_list):
    """Partition the evaluated test splits into in-domain and out-of-domain groups.

    Each training split whose name does not contain 'real' is mapped to a
    candidate in-domain test split by replacing 'train' with 'test' in its
    name. Only candidates that actually appear in ``test_split_list`` are kept
    (the others are reported with a warning), so later merging never looks up
    a split that was not evaluated. Every test split that is not in-domain is
    treated as out-of-domain.

    Args:
        train_split_list: names of the training splits.
        test_split_list: names of the evaluated (validation) splits.

    Returns:
        ``(in_domain, out_domain)``: two lists of test split names.
    """
    candidates = [s.replace('train', 'test') for s in train_split_list if 'real' not in s]
    missing = [s for s in candidates if s not in test_split_list]
    if missing:
        print(f"[Warning] in-domain splits not in val_split_list, ignored: {missing}")
    in_domain = [s for s in candidates if s in test_split_list]
    out_domain = [s for s in test_split_list if s not in in_domain]
    return in_domain, out_domain


def main(args):
    """Train a model described by a config file and periodically validate it.

    Steps:
      1. Load the config at ``args.config`` and derive runtime settings
         (learning rate and worker count are scaled by the number of devices).
      2. Build the training dataset/loader and one loader per validation split.
      3. Build the model (wrapped in ``nn.DataParallel``), optimizer and scheduler.
      4. Train for ``max_epochs`` epochs. On validation epochs:
         - evaluate every validation split;
         - merge the per-split result objects into "in_domain" / "out_domain"
           groups;
         - print the current results;
         - track the best average result per result key.

    Args:
        args: parsed CLI namespace; must provide ``config``. ``start_epoch``
            and ``print_freq`` are overwritten here.

    Raises:
        ValueError: if the config file does not exist.
        RuntimeError: if a validation epoch produced no "in_domain" or no
            "out_domain" merged result.
    """
    # 1. setup parameters / folders
    args.start_epoch = 0
    if not os.path.isfile(args.config):
        raise ValueError("Config file does not exist.")
    cfg = load_config(args.config)

    # TensorBoard logging is disabled; None is passed through to train_one_epoch.
    tb_writer = None

    # fix the random seeds (this will fix everything)
    rng_generator = fix_random_seed(cfg['init_rand_seed'], include_cuda=True)

    # re-scale learning rate / # workers based on number of GPUs
    num_devices = len(cfg['devices'])
    cfg['opt']["learning_rate"] *= num_devices
    cfg['loader']['num_workers'] *= num_devices

    cfg['dataset']['max_seq_len'] = cfg['dataset']['num_frames']
    cfg['save_root'] = 'model_ckpt'

    # 2. create datasets / data loaders
    train_dataset = make_dataset(
        cfg['dataset_name'], True, cfg['train_split_list'], **cfg['dataset']
    )
    # head_empty_cls is forced to an empty list instead of being read from
    # the dataset attributes.
    cfg['model']['train_cfg']['head_empty_cls'] = []
    train_loader = make_data_loader(
        train_dataset, True, rng_generator, **cfg['loader'])
    num_train_samples = len(train_dataset)

    # one loader per validation split (positional args mirror the training
    # call; the literal 1 is presumably the batch size — confirm in make_data_loader)
    val_loader_list = []
    for val_split in cfg['val_split_list']:
        val_dataset = make_dataset(
            cfg['dataset_name'], False, val_split, **cfg['dataset']
        )
        val_loader_list.append(make_data_loader(
            val_dataset, False, None, 1, cfg['loader']['num_workers']
        ))

    # 3. create model, optimizer, and scheduler
    model = make_meta_arch(cfg['model_name'], **cfg['model'])
    # not ideal for multi GPU training, ok for now
    model = nn.DataParallel(model, device_ids=cfg['devices'])
    optimizer = make_optimizer(model, cfg['opt'])
    num_iters_per_epoch = len(train_loader)
    scheduler = make_scheduler(optimizer, cfg['opt'], num_iters_per_epoch)

    # 4. misc: no evaluator and no output file are passed to valid_one_epoch
    args.print_freq = 100
    det_eval, output_file = None, None

    # 5. training / validation loop
    print("\nStart training model {:s} ...".format(cfg['model_name']))

    max_epochs = cfg['opt'].get(
        'early_stop_epochs',
        cfg['opt']['epochs'] + cfg['opt']['warmup_epochs']
    )

    model_ema = None
    new_best_per_split = None
    cfg['train_split'] = cfg['train_split_list'][0]
    cfg['test_split_list'] = cfg['val_split_list']

    # The domain partition depends only on the config, so compute it once.
    in_domain, out_domain = _split_domains(
        cfg['train_split_list'], cfg['test_split_list'])
    domain_groups = (('in_domain', in_domain), ('out_domain', out_domain))

    for epoch in range(args.start_epoch, max_epochs):
        # train for one epoch
        args.print_freq = 50
        train_one_epoch(
            train_loader,
            model,
            optimizer,
            scheduler,
            epoch,
            model_ema=model_ema,
            clip_grad_l2norm=cfg['train_cfg']['clip_grad_l2norm'],
            tb_writer=tb_writer,
            print_freq=args.print_freq
        )

        # validate every `valid_epoch` epochs once `start_test_epoch` is
        # reached, and always after the final epoch
        if (epoch % cfg['opt']['valid_epoch'] != 0 or epoch < cfg['opt']['start_test_epoch']) and epoch != max_epochs - 1:
            continue
        args.print_freq = 2000
        print('=' * 100)
        print(f'[Test]: Epoch {epoch} started')
        print('=' * 100)

        # per-split list of result objects, one per local_weight (dict order)
        split_results_obj_dict = {tmp_k: [] for tmp_k in cfg['val_split_list']}
        for val_split, val_loader in zip(cfg['val_split_list'], val_loader_list):
            _, _, result_save_obj_dict = valid_one_epoch(
                val_loader,
                model,
                -1,
                evaluator=det_eval,
                output_file=output_file,
                ext_score_file=cfg['test_cfg']['ext_score_file'],
                tb_writer=None,
                print_freq=args.print_freq,
            )
            split_results_obj_dict[val_split].extend(result_save_obj_dict.values())

        # Merged domain entries are appended after the per-split entries;
        # remember how many per-split entries there are so they can be skipped.
        num_separate = len(split_results_obj_dict)
        for domain_name, merge_combo in domain_groups:
            if not merge_combo:
                continue
            merge_key_name = "+".join(merge_combo)
            num_weights = len(split_results_obj_dict[merge_combo[0]])
            split_results_obj_dict[domain_name + f' ({merge_key_name})'] = [
                merge_ResultSaveObj([split_results_obj_dict[k][merge_idx] for k in merge_combo])
                for merge_idx in range(num_weights)
            ]

        # evaluate every (possibly merged) result object
        new_split_results_dict = collections.defaultdict(list)
        for entry_idx, (merge_k, merge_v_list) in enumerate(tqdm(split_results_obj_dict.items())):
            if entry_idx < num_separate and cfg['test_cfg']['skip_separate_flag']:
                continue
            for merge_v in merge_v_list:
                new_split_results_dict[merge_k].append(merge_v.eval())

        # Both merged domain groups must be present. Checked explicitly (not
        # via a leaked loop variable) so a stale key from a previous epoch
        # cannot satisfy the check.
        for domain_name, _ in domain_groups:
            if not any(f'{domain_name} ' in k for k in new_split_results_dict):
                raise RuntimeError(
                    f"No '{domain_name}' results were produced at epoch {epoch}")

        if new_best_per_split is None:
            new_best_per_split = {
                result_key: {
                    "best_avg": float("-inf"),
                    "best_epoch": None,
                    "best_local_weight": None,
                    "best_results": None,
                }
                for result_key in new_split_results_dict
            }

        # local_weight labels are taken from the last evaluated split
        local_weight_list = list(result_save_obj_dict.keys())
        print('=' * 100)
        print(f"Current Validation Results | Epoch {epoch} | Trained on {cfg['train_split_list']} ({num_train_samples} samples)")
        print('=' * 100)
        for merge_k, merge_v_list in new_split_results_dict.items():
            best = new_best_per_split[merge_k]
            for merge_v, local_weight in zip(merge_v_list, local_weight_list):
                avg_perf = get_average_performance(merge_v)
                print(f"Results for {merge_k}: avg={avg_perf:.4f} | epoch {epoch} | local_weight {local_weight}")
                print(display_python_performance(merge_v))
                if avg_perf > best["best_avg"]:
                    best["best_avg"] = avg_perf
                    best["best_epoch"] = epoch
                    best["best_local_weight"] = local_weight
                    best["best_results"] = merge_v
                    print("Update best results")
                print()

        print('=' * 100)
        print(f"Best Validation Results | Epoch {epoch} | Trained on {cfg['train_split_list']} ({num_train_samples} samples)")
        print('=' * 100)
        for val_split, rec in new_best_per_split.items():
            print(f"Best for {val_split}:\nR1 = {rec['best_avg']:.4f}\nepoch {rec['best_epoch']} | local_weight {rec['best_local_weight']}")
            print(display_python_performance(rec["best_results"]))
            print()

################################################################################
if __name__ == '__main__':
    # Entry point: declare the command-line options, parse them, run main().
    cli_parser = argparse.ArgumentParser(
        description='Train a point-based transformer for action localization')
    option_specs = (
        (('--config',),
         dict(metavar='DIR',
              help='path to a config file')),
        (('-p', '--print-freq'),
         dict(default=10, type=int,
              help='print frequency (default: 10 iterations)')),
        (('-c', '--ckpt-freq'),
         dict(default=5, type=int,
              help='checkpoint frequency (default: every 5 epochs)')),
        (('--output',),
         dict(default='', type=str,
              help='name of exp folder (default: none)')),
        (('--resume',),
         dict(default='', type=str, metavar='PATH',
              help='path to a checkpoint (default: none)')),
        (('--tag',),
         dict(default='baseline', type=str, help='experiment tag')),
    )
    for flags, options in option_specs:
        cli_parser.add_argument(*flags, **options)
    args = cli_parser.parse_args()

    main(args)