|
|
from loguru import logger
|
|
|
|
|
|
import torch
|
|
|
import torch.backends.cudnn as cudnn
|
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
|
|
from yolox.core import launch
|
|
|
from yolox.exp import get_exp
|
|
|
from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger
|
|
|
from yolox.evaluators import MOTEvaluatorDance as MOTEvaluator
|
|
|
|
|
|
from utils.args import make_parser, args_merge_params_form_exp
|
|
|
import os
|
|
|
import random
|
|
|
import warnings
|
|
|
import glob
|
|
|
import motmetrics as mm
|
|
|
from collections import OrderedDict
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
def compare_dataframes(gts, ts, dist='iou', distth=0.5):
    """Compare tracker results against ground truth, one entry at a time.

    Args:
        gts: mapping from a key (presumably a sequence name) to the ground-truth
            data for that key, in the form ``mm.utils.compare_to_groundtruth``
            expects.
        ts: mapping with the same kind of keys to the tracker-result data.
        dist: distance metric forwarded to ``mm.utils.compare_to_groundtruth``.
            Defaults to ``'iou'``, the value previously hard-coded here.
        distth: distance threshold forwarded to
            ``mm.utils.compare_to_groundtruth``. Defaults to ``0.5``, the value
            previously hard-coded here.

    Returns:
        A tuple ``(accs, names)``. ``accs`` holds one accumulator per key of
        ``ts`` that also exists in ``gts``. ``names`` holds the corresponding
        keys in the same order. Keys of ``ts`` without ground truth are skipped
        with a warning.
    """
    accs = []
    names = []
    for k, tsacc in ts.items():
        if k not in gts:
            logger.warning('No ground truth for {}, skipping.'.format(k))
            continue
        logger.info('Comparing {}...'.format(k))
        accs.append(mm.utils.compare_to_groundtruth(gts[k], tsacc, dist, distth=distth))
        names.append(k)

    return accs, names
|
|
|
|
|
|
|
|
|
def _result_dir_name(args):
    """Build the results sub-directory name from the split and tracker settings.

    The split tag is ``test`` when ``args.test`` is truthy, otherwise ``val``.
    The label/value pairs are emitted in a fixed order with exactly the same
    labels as before, so previously created result folders keep matching.
    """
    split = "test" if args.test else "val"
    # NOTE(review): "ScoreDifInterval"/"SecScoreDifInterval" label the TCM step
    # weights. The labels look historical; they are kept verbatim for path stability.
    tagged_params = (
        ("EGWeightHigh", args.EG_weight_high_score),
        ("EGWeightLow", args.EG_weight_low_score),
        ("WithLongTermReIDCorrection", args.with_longterm_reid_correction),
        ("LongTermReIDCorrectionThresh", args.longterm_reid_correction_thresh),
        ("LongTermReIDCorrectionThreshLow", args.longterm_reid_correction_thresh_low),
        ("IoUThresh", args.iou_thresh),
        ("ScoreDifInterval", args.TCM_first_step_weight),
        ("SecScoreDifInterval", args.TCM_byte_step_weight),
    )
    return "{}_{}".format(args.expn, split) + "".join(
        "_" + label + str(value) for label, value in tagged_params
    )


@logger.catch
def main(exp, args, num_gpu):
    """Run tracking evaluation in one (possibly distributed) worker process.

    Args:
        exp: experiment object returned by ``get_exp``. It supplies
            ``output_dir``, the model, the eval loader and test-time thresholds.
        args: parsed CLI namespace, already merged with experiment params.
        num_gpu: number of GPUs for this launch; more than one enables DDP.

    Returns:
        None. Results are written under
        ``<output_dir>/<expn>/<result dir name>`` by the evaluator.
    """
    seeded = args.seed is not None
    if seeded:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed testing. This will turn on the CUDNN deterministic setting, "
            "which can slow down your testing considerably!"
        )

    is_distributed = num_gpu > 1

    # cuDNN benchmark mode autotunes convolution algorithms, and the choice can
    # differ between runs. That defeats the deterministic setting requested by
    # seeding, so only enable it for unseeded runs.
    cudnn.benchmark = not seeded

    rank = args.local_rank
    file_name = os.path.join(exp.output_dir, args.expn)
    if rank == 0:
        os.makedirs(file_name, exist_ok=True)

    results_folder = os.path.join(file_name, _result_dir_name(args))
    os.makedirs(results_folder, exist_ok=True)

    setup_logger(file_name, distributed_rank=rank, filename="val_log.txt", mode="a")
    logger.info("Args: {}".format(args))

    # Optional CLI overrides of the experiment's test-time settings.
    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    val_loader = exp.get_eval_loader(args.batch_size, is_distributed, args.test, run_tracking=True)
    evaluator = MOTEvaluator(
        args=args,
        dataloader=val_loader,
        img_size=exp.test_size,
        confthre=exp.test_conf,
        nmsthre=exp.nmsthre,
        num_classes=exp.num_classes,
    )

    torch.cuda.set_device(rank)
    model.cuda(rank)
    model.eval()

    # Checkpoint weights are skipped for speed tests and TensorRT runs.
    if not args.speed and not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        loc = "cuda:{}".format(rank)
        ckpt = torch.load(ckpt_file, map_location=loc)
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if is_distributed:
        model = DDP(model, device_ids=[rank])

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.trt:
        assert (
            not args.fuse and not is_distributed and args.batch_size == 1
        ), "TensorRT model is not support model fusing and distributed inferencing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(
            trt_file
        ), "TensorRT model is not found!\n Run tools/trt.py first!"
        # With TensorRT, the head's decoding is done by the evaluator via
        # `decoder` instead of inside the model's forward pass.
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
    else:
        trt_file = None
        decoder = None

    evaluate = evaluator.evaluate_su_t_reid if args.with_reid else evaluator.evaluate_su_t
    *_, summary = evaluate(
        args, model, is_distributed, args.fp16, trt_file, decoder, exp.test_size, results_folder
    )

    # In test mode the summary is not logged. Presumably the test split has no
    # ground truth to score against — confirm.
    if args.test:
        return

    logger.info("\n" + summary)
    logger.info('Completed')
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse CLI args, load the experiment and merge its params into `args`.
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)
    args_merge_params_form_exp(args, exp)

    if not args.expn:
        args.expn = exp.exp_name

    available_gpus = torch.cuda.device_count()
    num_gpu = available_gpus if args.devices is None else args.devices
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if num_gpu > available_gpus:
        raise ValueError(
            "Requested {} GPU(s) but only {} CUDA device(s) are available.".format(
                num_gpu, available_gpus
            )
        )

    launch(
        main,
        num_gpu,
        args.num_machines,
        args.machine_rank,
        backend=args.dist_backend,
        dist_url=args.dist_url,
        args=(exp, args, num_gpu),
    )