| |
| import argparse |
| import os |
| import time |
| import datetime |
| import yaml |
| import json |
| from pprint import pprint |
|
|
| |
| import torch |
|
|
| import torch.nn as nn |
| import torch.utils.data |
| from torch.utils.tensorboard import SummaryWriter |
| from tqdm import tqdm, trange |
| |
| |
|
|
| |
| from libs.core import load_config |
| from libs.datasets import make_dataset, make_data_loader |
| from libs.modeling import make_meta_arch |
| from libs.utils import (train_one_epoch, valid_one_epoch, ANETdetection, |
| save_checkpoint, make_optimizer, make_scheduler, |
| fix_random_seed, ModelEma, display_python_performance, get_average_performance, merge_ResultSaveObj) |
| import itertools |
| import collections |
| from IPython import embed |
|
|
def load_json(filename):
    """Parse the JSON document stored at *filename* and return the result.

    The file is opened as UTF-8 text and closed again before returning.
    """
    with open(filename, encoding='utf8') as handle:
        parsed = json.load(handle)
    return parsed
|
|
| from terminaltables import AsciiTable |
|
|
| |
def _print_banner(title):
    """Print *title* framed above and below by a 100-character '=' rule."""
    rule = '=' * 100
    print(rule)
    print(title)
    print(rule)


def _group_splits_by_domain(train_split_list, test_split_list):
    """Return ``(in_domain, out_domain)`` lists of evaluation split names.

    In-domain names are the training split names that do not contain
    ``'real'``, with ``'train'`` replaced by ``'test'``. Out-of-domain names
    are the entries of *test_split_list* that are not in-domain.
    """
    in_domain = [
        name.replace('train', 'test')
        for name in train_split_list
        if 'real' not in name
    ]
    out_domain = [name for name in test_split_list if name not in in_domain]
    return in_domain, out_domain


def _add_merged_domain_results(split_results_obj_dict, in_domain, out_domain):
    """Add one merged result list per non-empty domain group, in place.

    The new key is ``"<domain> (<split>+<split>...)"``. Element ``i`` of the
    merged list combines element ``i`` of every split in the group via
    ``merge_ResultSaveObj``.
    """
    for domain_name, domain_splits in (('in_domain', in_domain),
                                       ('out_domain', out_domain)):
        if not domain_splits:
            continue
        joined_names = "+".join(domain_splits)
        # NOTE: every name in the group must be a key of
        # split_results_obj_dict (i.e. an evaluated split); otherwise the
        # lookups below raise KeyError.
        num_variants = len(split_results_obj_dict[domain_splits[0]])
        merged = [
            merge_ResultSaveObj(
                [split_results_obj_dict[name][idx] for name in domain_splits]
            )
            for idx in range(num_variants)
        ]
        split_results_obj_dict[f'{domain_name} ({joined_names})'] = merged


def main(args: argparse.Namespace) -> None:
    """Train the model and periodically evaluate it on the validation splits.

    Loads the config at ``args.config``, builds the data loaders, model,
    optimizer and scheduler, then runs the epoch loop. On validation epochs
    every split in ``cfg['val_split_list']`` is evaluated, per-split results
    are merged into an in-domain and an out-of-domain group, and both the
    current and the best-so-far results are printed.

    Args:
        args: Parsed command-line namespace. Only ``args.config`` is read;
            ``args.start_epoch`` and ``args.print_freq`` are overwritten.

    Raises:
        ValueError: If ``args.config`` is not an existing file.
        RuntimeError: If a validation pass produced no evaluated result for
            the merged in-domain or out-of-domain group.
    """
    # 1. Set up parameters.
    args.start_epoch = 0
    if not os.path.isfile(args.config):
        raise ValueError("Config file does not exist.")
    cfg = load_config(args.config)

    # No TensorBoard writer is created by this script.
    tb_writer = None

    rng_generator = fix_random_seed(cfg['init_rand_seed'], include_cuda=True)

    # Scale the learning rate and the worker count by the number of devices.
    num_devices = len(cfg['devices'])
    cfg['opt']["learning_rate"] *= num_devices
    cfg['loader']['num_workers'] *= num_devices

    cfg['dataset']['max_seq_len'] = cfg['dataset']['num_frames']
    cfg['save_root'] = 'model_ckpt'

    # 2. Create the training dataset / dataloader.
    train_dataset = make_dataset(
        cfg['dataset_name'], True, cfg['train_split_list'], **cfg['dataset']
    )
    num_train_samples = len(train_dataset)
    cfg['model']['train_cfg']['head_empty_cls'] = []
    train_loader = make_data_loader(
        train_dataset, True, rng_generator, **cfg['loader'])

    # 3. Create one validation dataloader per validation split.
    val_loader_list = []
    for val_split in cfg['val_split_list']:
        val_dataset = make_dataset(
            cfg['dataset_name'], False, val_split, **cfg['dataset']
        )
        val_loader_list.append(make_data_loader(
            val_dataset, False, None, 1, cfg['loader']['num_workers']
        ))

    # 4. Create model, optimizer, and scheduler.
    model = make_meta_arch(cfg['model_name'], **cfg['model'])
    model = nn.DataParallel(model, device_ids=cfg['devices'])
    optimizer = make_optimizer(model, cfg['opt'])
    num_iters_per_epoch = len(train_loader)
    scheduler = make_scheduler(optimizer, cfg['opt'], num_iters_per_epoch)

    # 5. Misc. NOTE: args.print_freq is assigned unconditionally here and
    # inside the loop, so a value given via --print-freq has no effect.
    args.print_freq = 100
    det_eval, output_file = None, None

    # 6. Training / validation loop.
    print("\nStart training model {:s} ...".format(cfg['model_name']))

    max_epochs = cfg['opt'].get(
        'early_stop_epochs',
        cfg['opt']['epochs'] + cfg['opt']['warmup_epochs']
    )
    valid_epoch = cfg['opt']['valid_epoch']
    start_test_epoch = cfg['opt']['start_test_epoch']

    model_ema = None
    new_best_per_split = None
    cfg['train_split'] = cfg['train_split_list'][0]
    cfg['test_split_list'] = cfg['val_split_list']
    for epoch in range(args.start_epoch, max_epochs):
        args.print_freq = 50
        train_one_epoch(
            train_loader,
            model,
            optimizer,
            scheduler,
            epoch,
            model_ema=model_ema,
            clip_grad_l2norm=cfg['train_cfg']['clip_grad_l2norm'],
            tb_writer=tb_writer,
            print_freq=args.print_freq
        )

        # Validate on scheduled epochs, and always after the final epoch.
        is_scheduled = (epoch % valid_epoch == 0
                        and epoch >= start_test_epoch)
        if not is_scheduled and epoch != max_epochs - 1:
            continue

        args.print_freq = 2000
        _print_banner(f'[Test]: Epoch {epoch} started')

        # Evaluate every split; one result object per local weight.
        split_results_obj_dict = {name: [] for name in cfg['val_split_list']}
        local_weight_list = []
        for val_split, val_loader in zip(cfg['val_split_list'],
                                         val_loader_list):
            _, _, result_save_obj_dict = valid_one_epoch(
                val_loader,
                model,
                -1,
                evaluator=det_eval,
                output_file=output_file,
                ext_score_file=cfg['test_cfg']['ext_score_file'],
                tb_writer=None,
                print_freq=args.print_freq,
            )
            # The keys of the last evaluated split label the results below.
            local_weight_list = list(result_save_obj_dict.keys())
            split_results_obj_dict[val_split].extend(
                result_save_obj_dict.values())

        # Entries before this index are the separate (unmerged) splits.
        num_separate = len(split_results_obj_dict)
        in_domain, out_domain = _group_splits_by_domain(
            cfg['train_split_list'], cfg['test_split_list'])
        _add_merged_domain_results(
            split_results_obj_dict, in_domain, out_domain)

        new_split_results_dict = collections.defaultdict(list)
        for position, (merge_k, merge_v_list) in enumerate(
                tqdm(split_results_obj_dict.items())):
            if (position < num_separate
                    and cfg['test_cfg']['skip_separate_flag']):
                continue
            for merge_v in merge_v_list:
                new_split_results_dict[merge_k].append(merge_v.eval())

        # Both merged groups must have produced evaluated results. Checked
        # explicitly rather than with `assert` so it survives `python -O`
        # and never inspects a stale or unbound loop variable.
        for marker in ('in_domain ', 'out_domain '):
            if not any(marker in key for key in new_split_results_dict):
                raise RuntimeError(
                    f"No evaluated results for the merged "
                    f"{marker.strip()!r} group."
                )

        if new_best_per_split is None:
            new_best_per_split = {
                val_split: {
                    "best_avg": float("-inf"),
                    "best_epoch": None,
                    "best_local_weight": None,
                    "best_results": None,
                }
                for val_split in new_split_results_dict
            }

        _print_banner(
            f"Current Validation Results | Epoch {epoch} | Trained on {cfg['train_split_list']} ({num_train_samples} samples)"
        )
        for merge_k, merge_v_list in new_split_results_dict.items():
            best = new_best_per_split[merge_k]
            for merge_v, local_weight in zip(merge_v_list, local_weight_list):
                avg_perf = get_average_performance(merge_v)
                print(f"Results for {merge_k}: avg={avg_perf:.4f} | epoch {epoch} | local_weight {local_weight}")
                print(display_python_performance(merge_v))
                if avg_perf > best["best_avg"]:
                    best["best_avg"] = avg_perf
                    best["best_epoch"] = epoch
                    best["best_local_weight"] = local_weight
                    best["best_results"] = merge_v
                    print("Update best results")
                print()

        _print_banner(
            f"Best Validation Results | Epoch {epoch} | Trained on {cfg['train_split_list']} ({num_train_samples} samples)"
        )
        for val_split in new_best_per_split:
            rec = new_best_per_split[val_split]
            print(f"Best for {val_split}:\nR1 = {rec['best_avg']:.4f}\nepoch {rec['best_epoch']} | local_weight {rec['best_local_weight']}")
            print(display_python_performance(rec["best_results"]))
            print()
|
|
| |
def build_arg_parser():
    """Build the command-line parser for this training script.

    ``--config`` is required: without it ``args.config`` would be ``None``
    and the config-file check in ``main`` would fail with a ``TypeError``
    instead of a usage message.

    Returns:
        argparse.ArgumentParser: Parser exposing ``--config``,
        ``-p/--print-freq``, ``-c/--ckpt-freq``, ``--output``, ``--resume``
        and ``--tag``.
    """
    parser = argparse.ArgumentParser(
        description='Train a point-based transformer for action localization')
    parser.add_argument('--config', metavar='DIR', required=True,
                        help='path to a config file')
    parser.add_argument('-p', '--print-freq', default=10, type=int,
                        help='print frequency (default: 10 iterations)')
    parser.add_argument('-c', '--ckpt-freq', default=5, type=int,
                        help='checkpoint frequency (default: every 5 epochs)')
    parser.add_argument('--output', default='', type=str,
                        help='name of exp folder (default: none)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to a checkpoint (default: none)')
    parser.add_argument('--tag', default='baseline', type=str,
                        help='experiment tag')
    return parser


if __name__ == '__main__':
    # Entry point: parse the command line and start training.
    args = build_arg_parser().parse_args()
    main(args)
|
|