# vicca / VG / train.py
# sayehghp's picture
# Add application file
# 0f8411f
# Copyright (c) 2022 IDEA. All Rights Reserved.
# ------------------------------------------------------------------------
import argparse
import datetime
import json
import random
import time
from pathlib import Path
import os, sys
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler
from util.get_param_dicts import get_param_dict
from util.logger import setup_logger
from util.slconfig import DictAction, SLConfig
from util.utils import BestMetricHolder
import util.misc as utils
import datasets
from datasets import build_dataset, get_coco_api_from_dataset
from engine import evaluate, train_one_epoch
from groundingdino.util.utils import clean_state_dict
# NOTE(review): hard-coded, machine-specific directory inserted at position 1
# of the module search path. Presumably this lets modules stored next to the
# MIMIC-CXR-JPG files be imported -- confirm, and consider making it
# configurable, since the path will not exist on other machines.
sys.path.insert(1,"/home/gholipos/physionet.org/files/mimic-cxr-jpg/2.0.0/")
def get_args_parser():
    """Build the command-line parser for the training / evaluation script.

    The parser is created with ``add_help=False`` so that it can be used as a
    ``parents=[...]`` entry of another ``argparse.ArgumentParser`` (which then
    owns the ``-h/--help`` option).

    Returns:
        argparse.ArgumentParser: parser holding the config, dataset, training
        and distributed-training options.
    """
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--config_file', '-c', type=str, required=True)
    parser.add_argument('--options',
                        nargs='+',
                        action=DictAction,
                        help='override some settings in the used config, the key-value pair '
                             'in xxx=yyy format will be merged into config file.')

    # dataset parameters
    parser.add_argument("--datasets", type=str, required=True, help='path to datasets json')
    parser.add_argument('--remove_difficult', action='store_true')
    parser.add_argument('--fix_size', action='store_true')

    # training parameters
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--note', default='',
                        help='add some notes to the experiment')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--pretrain_model_path', help='load from other checkpoint')
    parser.add_argument('--finetune_ignore', type=str, nargs='+')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--find_unused_params', action='store_true')
    parser.add_argument('--save_results', action='store_true')
    parser.add_argument('--save_log', action='store_true')

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # The help text used to repeat the --world_size description.
    parser.add_argument('--rank', default=0, type=int,
                        help='rank of this process among the distributed processes')
    # Both spellings are accepted and stored in ``args.local_rank``; they were
    # previously registered as two separate arguments sharing that destination.
    parser.add_argument("--local_rank", "--local-rank", type=int,
                        help='local rank for DistributedDataParallel')
    parser.add_argument('--amp', action='store_true',
                        help="Train with mixed precision")
    return parser
def build_model_main(args):
    """Look up the builder registered for ``args.modelname`` and call it.

    Args:
        args: namespace that must carry ``modelname``; it is forwarded
            unchanged to the registered build function.

    Returns:
        tuple: ``(model, criterion, postprocessors)`` as produced by the
        registered build function.
    """
    # we use register to maintain models from catdet6 on.
    from models.registry import MODULE_BUILD_FUNCS

    # Fail early when no builder is registered under the requested name.
    assert args.modelname in MODULE_BUILD_FUNCS._module_dict
    builder = MODULE_BUILD_FUNCS.get(args.modelname)
    net, loss_fn, post = builder(args)
    return net, loss_fn, post
def main(args):
    """Run training, or a single evaluation pass when ``args.eval`` is set.

    The function merges the config file into ``args``, builds the model,
    optimizer, datasets and data loaders, optionally restores weights from a
    checkpoint, and then either evaluates once and returns or runs the epoch
    loop (train, checkpoint, evaluate, log).

    Args:
        args: namespace produced by ``get_args_parser()``. It is modified in
            place: every key of the config file is added as an attribute, and
            ``resume`` / ``start_epoch`` / ``coco_val_path`` may be updated.

    Raises:
        ValueError: if a key of the config file is already an attribute of
            ``args`` (config keys and command-line options must not overlap).

    Files written under ``args.output_dir`` (by this function or the helpers
    it calls): ``config_cfg.py``, ``config_args_raw.json``,
    ``config_args_all.json``, ``info.txt``, ``log.txt``, ``checkpoint*.pth``,
    ``eval.pth`` and ``eval/*.pth``.
    """
    utils.setup_distributed(args)

    # load cfg file and update the args
    print("Loading config file from {}".format(args.config_file))
    # Delay grows with the rank; presumably this staggers config access
    # across processes -- TODO confirm.
    time.sleep(args.rank * 0.02)
    cfg = SLConfig.fromfile(args.config_file)
    if args.options is not None:
        cfg.merge_from_dict(args.options)

    # Create the output directory *before* the first file is written into it.
    # The guard avoids os.makedirs('') raising for the default empty value.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    if args.rank == 0:
        save_cfg_path = os.path.join(args.output_dir, "config_cfg.py")
        cfg.dump(save_cfg_path)
        save_json_path = os.path.join(args.output_dir, "config_args_raw.json")
        with open(save_json_path, 'w') as f:
            json.dump(vars(args), f, indent=2)

    # Copy every config entry onto args; clashes with CLI options are errors.
    cfg_dict = cfg._cfg_dict.to_dict()
    args_vars = vars(args)
    for k, v in cfg_dict.items():
        if k not in args_vars:
            setattr(args, k, v)
        else:
            raise ValueError("Key {} can used by args only".format(k))

    # update some new args temporally
    if not getattr(args, 'debug', None):
        args.debug = False

    # setup logger
    logger = setup_logger(output=os.path.join(args.output_dir, 'info.txt'), distributed_rank=args.rank, color=False, name="detr")
    logger.info("git:\n {}\n".format(utils.get_sha()))
    logger.info("Command: "+' '.join(sys.argv))
    if args.rank == 0:
        save_json_path = os.path.join(args.output_dir, "config_args_all.json")
        with open(save_json_path, 'w') as f:
            json.dump(vars(args), f, indent=2)
        logger.info("Full config saved to {}".format(save_json_path))

    with open(args.datasets) as f:
        dataset_meta = json.load(f)
    if args.use_coco_eval:
        args.coco_val_path = dataset_meta["val"][0]["anno"]

    logger.info('world size: {}'.format(args.world_size))
    logger.info('rank: {}'.format(args.rank))
    logger.info('local_rank: {}'.format(args.local_rank))
    logger.info("args: " + str(args) + '\n')

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so processes differ)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    logger.debug("build model ... ...")
    model, criterion, postprocessors = build_model_main(args)
    wo_class_error = False
    model.to(device)
    logger.debug("build model, done.")

    # Keep a handle on the unwrapped module for state-dict save / load.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=args.find_unused_params)
        model._set_static_graph()
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info('number of params:'+str(n_parameters))
    logger.info("params before freezing:\n"+json.dumps({n: p.numel() for n, p in model.named_parameters() if p.requires_grad}, indent=2))

    # Parameter groups are collected before freezing, as in the original flow.
    param_dicts = get_param_dict(args, model_without_ddp)

    # freeze some layers: any parameter whose name contains a keyword
    if args.freeze_keywords is not None:
        for name, parameter in model.named_parameters():
            if any(keyword in name for keyword in args.freeze_keywords):
                parameter.requires_grad_(False)
    logger.info("params after freezing:\n"+json.dumps({n: p.numel() for n, p in model.named_parameters() if p.requires_grad}, indent=2))

    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                  weight_decay=args.weight_decay)

    logger.debug("build dataset ... ...")
    if not args.eval:
        num_of_dataset_train = len(dataset_meta["train"])
        if num_of_dataset_train == 1:
            dataset_train = build_dataset(image_set='train', args=args, datasetinfo=dataset_meta["train"][0])
        else:
            from torch.utils.data import ConcatDataset
            dataset_train = ConcatDataset([
                build_dataset(image_set='train', args=args, datasetinfo=info)
                for info in dataset_meta["train"]
            ])
        logger.debug("build dataset, done.")
        logger.debug(f'number of training dataset: {num_of_dataset_train}, samples: {len(dataset_train)}')

    # Only the first entry of the "val" list is used for validation.
    dataset_val = build_dataset(image_set='val', args=args, datasetinfo=dataset_meta["val"][0])

    if args.distributed:
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
        if not args.eval:
            sampler_train = DistributedSampler(dataset_train)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        if not args.eval:
            sampler_train = torch.utils.data.RandomSampler(dataset_train)

    if not args.eval:
        batch_sampler_train = torch.utils.data.BatchSampler(
            sampler_train, args.batch_size, drop_last=True)
        data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                       collate_fn=utils.collate_fn, num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val, 4, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)

    # The scheduler is only needed for training. Building it in eval mode
    # used to fail with onecyclelr because data_loader_train does not exist
    # there, so it is skipped entirely when evaluating.
    lr_scheduler = None
    if not args.eval:
        if args.onecyclelr:
            lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, steps_per_epoch=len(data_loader_train), epochs=args.epochs, pct_start=0.2)
        elif args.multi_step_lr:
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_drop_list)
        else:
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    base_ds = get_coco_api_from_dataset(dataset_val)

    # NOTE(security): torch.load unpickles the file contents; only point
    # frozen_weights / resume / pretrain_model_path at trusted checkpoints.
    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(clean_state_dict(checkpoint['model']),strict=False)

    output_dir = Path(args.output_dir)
    # An existing checkpoint.pth in the output directory takes precedence
    # over whatever --resume was given.
    if os.path.exists(os.path.join(args.output_dir, 'checkpoint.pth')):
        args.resume = os.path.join(args.output_dir, 'checkpoint.pth')
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(clean_state_dict(checkpoint['model']),strict=False)
        # Restore the training state only when all of it is available.
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if (not args.resume) and args.pretrain_model_path:
        checkpoint = torch.load(args.pretrain_model_path, map_location='cpu')['model']
        from collections import OrderedDict
        _ignorekeywordlist = args.finetune_ignore if args.finetune_ignore else []
        ignorelist = []

        def check_keep(keyname, ignorekeywordlist):
            # Drop (and record) any key containing one of the ignore keywords.
            for keyword in ignorekeywordlist:
                if keyword in keyname:
                    ignorelist.append(keyname)
                    return False
            return True

        _tmp_st = OrderedDict({k:v for k, v in utils.clean_state_dict(checkpoint).items() if check_keep(k, _ignorekeywordlist)})
        # Log after filtering: ignorelist is only filled while _tmp_st is
        # built, so logging it earlier always reported an empty list.
        logger.info("Ignore keys: {}".format(json.dumps(ignorelist, indent=2)))
        _load_output = model_without_ddp.load_state_dict(_tmp_st, strict=False)
        logger.info(str(_load_output))

    if args.eval:
        os.environ['EVAL_FLAG'] = 'TRUE'
        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device, args.output_dir, wo_class_error=wo_class_error, args=args)
        if args.output_dir:
            utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
        log_stats = {**{f'test_{k}': v for k, v in test_stats.items()} }
        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
        return

    print("Start training")
    start_time = time.time()
    best_map_holder = BestMetricHolder(use_ema=False)
    for epoch in range(args.start_epoch, args.epochs):
        epoch_start_time = time.time()
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train, optimizer, device, epoch,
            args.clip_max_norm, wo_class_error=wo_class_error, lr_scheduler=lr_scheduler, args=args, logger=(logger if args.save_log else None))

        # With onecyclelr the scheduler is not stepped here; it is handed to
        # train_one_epoch instead (presumably stepped per iteration there).
        if not args.onecyclelr:
            lr_scheduler.step()

        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every
            # save_checkpoint_interval epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % args.save_checkpoint_interval == 0:
                checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                weights = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                }
                utils.save_on_master(weights, checkpoint_path)

        # eval
        test_stats, coco_evaluator = evaluate(
            model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir,
            wo_class_error=wo_class_error, args=args, logger=(logger if args.save_log else None)
        )
        # The first entry of the bbox stats is the metric used to pick the
        # best checkpoint.
        map_regular = test_stats['coco_eval_bbox'][0]
        _isbest = best_map_holder.update(map_regular, epoch, is_ema=False)
        if _isbest:
            checkpoint_path = output_dir / 'checkpoint_best_regular.pth'
            utils.save_on_master({
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args,
            }, checkpoint_path)

        log_stats = {
            **{f'train_{k}': v for k, v in train_stats.items()},
            **{f'test_{k}': v for k, v in test_stats.items()},
        }
        log_stats['now_time'] = str(datetime.datetime.now())
        epoch_time = time.time() - epoch_start_time
        epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
        log_stats['epoch_time'] = epoch_time_str

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # for evaluation logs
            if coco_evaluator is not None:
                (output_dir / 'eval').mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ['latest.pth']
                    if epoch % 50 == 0:
                        filenames.append(f'{epoch:03}.pth')
                    for name in filenames:
                        torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                   output_dir / "eval" / name)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    # remove the copied files.
    copyfilelist = vars(args).get('copyfilelist')
    if copyfilelist and args.local_rank == 0:
        from datasets.data_util import remove
        for filename in copyfilelist:
            print("Removing: {}".format(filename))
            remove(filename)
if __name__ == '__main__':
    # Script entry point: parse the command line, make sure the output
    # directory exists when one was given, then hand over to main().
    parser = argparse.ArgumentParser(
        'DETR training and evaluation script',
        parents=[get_args_parser()],
    )
    args = parser.parse_args()
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    main(args)