import os
import argparse
from glob import glob

import prettytable as pt

from evaluation.metrics import evaluator
from config import Config


config = Config()


def do_eval(args):
    # Evaluate dataset by dataset, then model by model within each dataset.
    for _data_name in args.data_lst.split("+"):
        pred_data_dir = sorted(glob(os.path.join(args.pred_root, args.model_lst[0], _data_name)))
        if not pred_data_dir:
            print("Skip dataset {}.".format(_data_name))
            continue
        gt_src = os.path.join(args.gt_root, _data_name)
        gt_paths = sorted(glob(os.path.join(gt_src, "gt", "*")))
        print("#" * 20, _data_name, "#" * 20)
        filename = os.path.join(args.save_dir, "{}_eval.txt".format(_data_name))
        tb = pt.PrettyTable()
        tb.vertical_char = "&"
        # Column order is task-specific; DIS5K, General and General-2K share the same layout.
        if config.task in ("DIS5K", "General", "General-2K"):
            tb.field_names = [
                "Dataset", "Method", "maxFm", "wFmeasure", "MAE", "Smeasure", "meanEm", "HCE",
                "maxEm", "meanFm", "adpEm", "adpFm", "mBA", "maxBIoU", "meanBIoU",
            ]
        elif config.task == "COD":
            tb.field_names = [
                "Dataset", "Method", "Smeasure", "wFmeasure", "meanFm", "meanEm", "maxEm", "MAE",
                "maxFm", "adpEm", "adpFm", "HCE", "mBA", "maxBIoU", "meanBIoU",
            ]
        elif config.task == "HRSOD":
            tb.field_names = [
                "Dataset", "Method", "Smeasure", "maxFm", "meanEm", "MAE", "maxEm", "meanFm",
                "wFmeasure", "adpEm", "adpFm", "HCE", "mBA", "maxBIoU", "meanBIoU",
            ]
        elif config.task == "Matting":
            tb.field_names = [
                "Dataset", "Method", "Smeasure", "maxFm", "meanEm", "MSE", "maxEm", "meanFm",
                "wFmeasure", "adpEm", "adpFm", "HCE", "mBA", "maxBIoU", "meanBIoU",
            ]
        else:
            tb.field_names = [
                "Dataset", "Method", "Smeasure", "MAE", "maxEm", "meanEm", "maxFm", "meanFm",
                "wFmeasure", "adpEm", "adpFm", "HCE", "mBA", "maxBIoU", "meanBIoU",
            ]
        for _model_name in args.model_lst:
            print("\t", "Evaluating model: {}...".format(_model_name))
            # Prediction paths mirror the ground-truth paths, minus the "/gt/" sub-folder.
            pred_paths = [
                p.replace(args.gt_root, os.path.join(args.pred_root, _model_name)).replace("/gt/", "/")
                for p in gt_paths
            ]
            # print(pred_paths[:1], gt_paths[:1])
            em, sm, fm, mae, mse, wfm, hce, mba, biou = evaluator(
                gt_paths=gt_paths,
                pred_paths=pred_paths,
                metrics=args.metrics.split("+"),
                verbose=config.verbose_eval,
            )
            # Scores must follow the same order as the field names chosen above.
            if config.task in ("DIS5K", "General", "General-2K"):
                scores = [
                    fm["curve"].max().round(3), wfm.round(3), mae.round(3), sm.round(3),
                    em["curve"].mean().round(3), int(hce.round()), em["curve"].max().round(3),
                    fm["curve"].mean().round(3), em["adp"].round(3), fm["adp"].round(3),
                    mba.round(3), biou["curve"].max().round(3), biou["curve"].mean().round(3),
                ]
            elif config.task == "COD":
                scores = [
                    sm.round(3), wfm.round(3), fm["curve"].mean().round(3), em["curve"].mean().round(3),
                    em["curve"].max().round(3), mae.round(3), fm["curve"].max().round(3),
                    em["adp"].round(3), fm["adp"].round(3), int(hce.round()),
                    mba.round(3), biou["curve"].max().round(3), biou["curve"].mean().round(3),
                ]
            elif config.task == "HRSOD":
                scores = [
                    sm.round(3), fm["curve"].max().round(3), em["curve"].mean().round(3), mae.round(3),
                    em["curve"].max().round(3), fm["curve"].mean().round(3), wfm.round(3),
                    em["adp"].round(3), fm["adp"].round(3), int(hce.round()),
                    mba.round(3), biou["curve"].max().round(3), biou["curve"].mean().round(3),
                ]
            elif config.task == "Matting":
                scores = [
                    sm.round(3), fm["curve"].max().round(3), em["curve"].mean().round(3), mse.round(5),
                    em["curve"].max().round(3), fm["curve"].mean().round(3), wfm.round(3),
                    em["adp"].round(3), fm["adp"].round(3), int(hce.round()),
                    mba.round(3), biou["curve"].max().round(3), biou["curve"].mean().round(3),
                ]
            else:
                scores = [
                    sm.round(3), mae.round(3), em["curve"].max().round(3), em["curve"].mean().round(3),
                    fm["curve"].max().round(3), fm["curve"].mean().round(3), wfm.round(3),
                    em["adp"].round(3), fm["adp"].round(3), int(hce.round()),
                    mba.round(3), biou["curve"].max().round(3), biou["curve"].mean().round(3),
                ]
            # Drop the leading zero of scores in [0, 1] to keep the table compact.
            for idx_score, score in enumerate(scores):
                scores[idx_score] = (
                    "." + format(score, ".3f").split(".")[-1] if score <= 1 else format(score, "<4")
                )
            records = [_data_name, _model_name] + scores
            tb.add_row(records)
            # Rewrite the result file after each model, so partial results are kept.
            with open(filename, "w+") as file_to_write:
                file_to_write.write(str(tb) + "\n")
        print(tb)


if __name__ == "__main__":
    # Set parameters.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gt_root", type=str, help="ground-truth root",
        default=os.path.join(config.data_root_dir, config.task),
    )
    parser.add_argument(
        "--pred_root", type=str, help="prediction root", default="./e_preds"
    )
    parser.add_argument(
        "--data_lst", type=str, help="test datasets, joined by '+'",
        default=config.testsets.replace(",", "+"),
    )
    parser.add_argument(
        "--save_dir", type=str, help="directory to save the evaluation results", default="e_results"
    )
    parser.add_argument(
        "--check_integrity",
        type=lambda x: str(x).lower() in ("true", "1", "yes"),
        help="whether to check that prediction and ground-truth file lists match",
        default=False,
    )
    parser.add_argument(
        "--metrics", type=str, help="metrics to compute, joined by '+'",
        default="+".join(
            ["S", "MAE", "E", "F", "WF", "MBA", "BIoU", "MSE", "HCE"][
                : 100 if "DIS5K" in config.task else -1
            ]
        ),
    )
    args = parser.parse_args()

    # Re-derive the metric list from the chosen datasets: HCE (the last metric) is kept
    # only when DIS datasets are evaluated. This overrides any --metrics value passed above.
    args.metrics = "+".join(
        ["S", "MAE", "E", "F", "WF", "MBA", "BIoU", "MSE", "HCE"][
            : 100 if sum(["DIS-" in _data for _data in args.data_lst.split("+")]) else -1
        ]
    )

    os.makedirs(args.save_dir, exist_ok=True)
    try:
        # Prediction folders are expected to end with "epoch_<N>"; sort them by epoch,
        # newest first. Raise the modulus to subsample checkpoints (e.g. every 5th epoch).
        args.model_lst = [
            m for m in sorted(
                os.listdir(args.pred_root),
                key=lambda x: int(x.split("epoch_")[-1]),
                reverse=True,
            )
            if int(m.split("epoch_")[-1]) % 1 == 0
        ]
    except ValueError:
        # Folder names without an epoch suffix: fall back to alphabetical order.
        args.model_lst = sorted(os.listdir(args.pred_root))

    # Check the integrity of each candidate.
    if args.check_integrity:
        for _data_name in args.data_lst.split("+"):
            for _model_name in args.model_lst:
                gt_pth = os.path.join(args.gt_root, _data_name)
                pred_pth = os.path.join(args.pred_root, _model_name, _data_name)
                if sorted(os.listdir(gt_pth)) != sorted(os.listdir(pred_pth)):
                    print(len(sorted(os.listdir(gt_pth))), len(sorted(os.listdir(pred_pth))))
                    print(
                        "Predictions of model {} on dataset {} do not match the ground truth.".format(
                            _model_name, _data_name
                        )
                    )
    else:
        print(">>> Skipping the integrity check of candidates.")

    # Start engine.
    do_eval(args)
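
# A sketch of how this script might be invoked (the script name and dataset names
# below are placeholders, not taken from this file; adjust them to your setup):
#
#   python eval.py --pred_root ./e_preds --data_lst DIS-VD+DIS-TE1 --save_dir e_results
#
# Predictions are expected under <pred_root>/<model_name>/<dataset_name>/, mirroring
# the ground-truth layout <gt_root>/<dataset_name>/gt/, with matching file names.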