forensics-grpo / code /train.py
sdzt's picture
Add source code
33569f9 verified
Raw
History Blame Contribute Delete
10.6 kB
# python imports
import argparse
import os
import time
import datetime
import yaml
import json
from pprint import pprint
# torch imports
import torch
import torch.nn as nn
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange
# for visualization
# from torch.utils.tensorboard import SummaryWriter
# our code
from libs.core import load_config
from libs.datasets import make_dataset, make_data_loader
from libs.modeling import make_meta_arch
from libs.utils import (train_one_epoch, valid_one_epoch, ANETdetection,
save_checkpoint, make_optimizer, make_scheduler,
fix_random_seed, ModelEma, display_python_performance, get_average_performance, merge_ResultSaveObj)
import itertools
import collections
from IPython import embed
def load_json(filename):
with open(filename, encoding='utf8') as fr:
return json.load(fr)
from terminaltables import AsciiTable
################################################################################
def main(args):
"""main function that handles training / inference"""
"""1. setup parameters / folders"""
# parse args
args.start_epoch = 0
if os.path.isfile(args.config):
cfg = load_config(args.config)
else:
raise ValueError("Config file does not exist.")
# pprint(cfg)
# tensorboard writer
tb_writer = None
# fix the random seeds (this will fix everything)
rng_generator = fix_random_seed(cfg['init_rand_seed'], include_cuda=True)
# re-scale learning rate / # workers based on number of GPUs
cfg['opt']["learning_rate"] *= len(cfg['devices'])
cfg['loader']['num_workers'] *= len(cfg['devices'])
cfg['dataset']['max_seq_len'] = cfg['dataset']['num_frames']
cfg['save_root'] = os.path.join('model_ckpt')
"""2. create dataset / dataloader"""
train_dataset = make_dataset(
cfg['dataset_name'], True, cfg['train_split_list'], **cfg['dataset']
)
# update cfg based on dataset attributes (fix to epic-kitchens)
# train_db_vars = train_dataset.get_attributes()
cfg['model']['train_cfg']['head_empty_cls'] = []
# data loaders
train_loader = make_data_loader(
train_dataset, True, rng_generator, **cfg['loader'])
"""2. create dataset / dataloader"""
val_dataset_list = []
val_loader_list = []
for val_split in cfg['val_split_list']:
val_dataset = make_dataset(
cfg['dataset_name'], False, val_split, **cfg['dataset']
)
val_loader = make_data_loader(
val_dataset, False, None, 1, cfg['loader']['num_workers']
)
val_dataset_list.append(val_dataset)
val_loader_list.append(val_loader)
"""3. create model, optimizer, and scheduler"""
# model
model = make_meta_arch(cfg['model_name'], **cfg['model'])
# model.load_state_dict(torch.load(os.path.join(cfg['save_root'], '001.pth')))
# embed()
# not ideal for multi GPU training, ok for now
model = nn.DataParallel(model, device_ids=cfg['devices'])
# optimizer
optimizer = make_optimizer(model, cfg['opt'])
# schedule
num_iters_per_epoch = len(train_loader)
scheduler = make_scheduler(optimizer, cfg['opt'], num_iters_per_epoch)
"""4. Resume from model / Misc"""
args.print_freq = 100
det_eval, output_file = None, None
"""5. Test the model"""
"""4. training / validation loop"""
print("\nStart training model {:s} ...".format(cfg['model_name']))
# start training
max_epochs = cfg['opt'].get(
'early_stop_epochs',
cfg['opt']['epochs'] + cfg['opt']['warmup_epochs']
)
model_ema = None
new_best_per_split = None
cfg['train_split'] = cfg['train_split_list'][0]
cfg['test_split_list'] = cfg['val_split_list']
for epoch in range(args.start_epoch, max_epochs):
# train for one epoch
args.print_freq = 50
train_one_epoch(
train_loader,
model,
optimizer,
scheduler,
epoch,
model_ema = model_ema,
clip_grad_l2norm = cfg['train_cfg']['clip_grad_l2norm'],
tb_writer=tb_writer,
print_freq=args.print_freq
)
if (epoch % cfg['opt']['valid_epoch'] != 0 or epoch < cfg['opt']['start_test_epoch']) and epoch != max_epochs - 1:
continue
args.print_freq = 2000
print('=' * 100)
print(f'[Test]: Epoch {epoch} started')
print('=' * 100)
split_results_dict = {tmp_k: [] for tmp_k in cfg['val_split_list']}
split_results_obj_dict = {tmp_k: [] for tmp_k in cfg['val_split_list']}
for val_split, val_loader in zip(cfg['val_split_list'], val_loader_list):
split_output_file = output_file
_, acc_results, result_save_obj_dict = valid_one_epoch(
val_loader,
model,
-1,
evaluator=det_eval,
output_file=split_output_file,
ext_score_file=cfg['test_cfg']['ext_score_file'],
tb_writer=None,
print_freq=args.print_freq,
)
# 计算平均性能指标
for local_weight in result_save_obj_dict:
val_results_obj = result_save_obj_dict[local_weight]
split_results_obj_dict[val_split].append(val_results_obj)
merge_keys = cfg['val_split_list']
new_split_results_dict = collections.defaultdict(list)
in_domain = [tmp_itm.replace('train', 'test') for tmp_itm in cfg['train_split_list'] if 'real' not in tmp_itm]
out_domain_2 = [tmp_itm for tmp_itm in cfg['test_split_list'] if tmp_itm not in in_domain]
domain_name_list = ['in_domain', 'out_domain']
tqdm_list = [in_domain, out_domain_2]
domain_name_id = -1
start_add = len(split_results_obj_dict)
for merge_combo in tqdm_list:
domain_name_id += 1
if len(merge_combo) <= 0:
continue
merge_key_name = "+".join(merge_combo)
merge_result_list = []
# embed()
for merge_idx in range(len(split_results_obj_dict[merge_combo[0]])):
merge_objs = [split_results_obj_dict[k][merge_idx] for k in merge_combo]
merge_obj = merge_ResultSaveObj(merge_objs)
merge_result_list.append(merge_obj)
split_results_obj_dict[domain_name_list[domain_name_id]+f' ({merge_key_name})'] = merge_result_list
tqdm_list = tqdm(split_results_obj_dict.items())
start_id = -1
for merge_k, merge_v_list in tqdm_list:
start_id += 1
if start_id < start_add and cfg['test_cfg']['skip_separate_flag']:
continue
for merge_v in merge_v_list:
new_split_results_dict[merge_k].append(merge_v.eval())
# embed()
for test_split_key in new_split_results_dict:
if 'in_domain ' in test_split_key:
break
assert 'in_domain ' in test_split_key
for test_split_key_assist in new_split_results_dict:
if 'out_domain ' in test_split_key_assist:
break
assert 'out_domain ' in test_split_key_assist
if new_best_per_split is None:
new_best_per_split = {
val_split: {
"best_avg": float("-inf"),
"best_epoch": None,
"best_local_weight": None,
"best_results": None,
}
for val_split in new_split_results_dict
}
local_weight_list = list(result_save_obj_dict.keys())
print('='*100)
num_train_samples = len(train_dataset)
print(f"Current Validation Results | Epoch {epoch} | Trained on {cfg['train_split_list']} ({num_train_samples} samples)")
print('=' * 100)
for merge_k, merge_v_list in new_split_results_dict.items():
for merge_v, local_weight in zip(merge_v_list, local_weight_list):
avg_perf = get_average_performance(merge_v)
print(f"Results for {merge_k}: avg={avg_perf:.4f} | epoch {epoch} | local_weight {local_weight}")
print(display_python_performance(merge_v))
if avg_perf > new_best_per_split[merge_k]["best_avg"]:
new_best_per_split[merge_k]["best_avg"] = avg_perf
new_best_per_split[merge_k]["best_epoch"] = epoch
new_best_per_split[merge_k]["best_local_weight"] = local_weight
new_best_per_split[merge_k]["best_results"] = merge_v
# print(f"Update best results for {merge_k}: avg={avg_perf:.4f} | epoch {epoch} | local_weight {local_weight}")
print(f"Update best results")
print()
# print('='*100)
num_train_samples = len(train_dataset)
print('='*100)
print(f"Best Validation Results | Epoch {epoch} | Trained on {cfg['train_split_list']} ({num_train_samples} samples)")
print('='*100)
for val_split in new_best_per_split:
rec = new_best_per_split[val_split]
print(f"Best for {val_split}:\nR1 = {rec['best_avg']:.4f}\nepoch {rec['best_epoch']} | local_weight {rec['best_local_weight']}")
print(display_python_performance(rec["best_results"]))
print()
################################################################################
if __name__ == '__main__':
"""Entry Point"""
# the arg parser
parser = argparse.ArgumentParser(
description='Train a point-based transformer for action localization')
parser.add_argument('--config', metavar='DIR',
help='path to a config file')
parser.add_argument('-p', '--print-freq', default=10, type=int,
help='print frequency (default: 10 iterations)')
parser.add_argument('-c', '--ckpt-freq', default=5, type=int,
help='checkpoint frequency (default: every 5 epochs)')
parser.add_argument('--output', default='', type=str,
help='name of exp folder (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to a checkpoint (default: none)')
parser.add_argument('--tag', default='baseline', type=str, help='experiment tag')
args = parser.parse_args()
main(args)