| | import os |
| | import argparse |
| | from glob import glob |
| | import prettytable as pt |
| |
|
| | from evaluation.metrics import evaluator |
| | from config import Config |
| |
|
| |
|
| | config = Config() |
| |
|
| |
|
| | def do_eval(args): |
| | |
| | |
| | for _data_name in args.data_lst.split("+"): |
| | pred_data_dir = sorted( |
| | glob(os.path.join(args.pred_root, args.model_lst[0], _data_name)) |
| | ) |
| | if not pred_data_dir: |
| | print("Skip dataset {}.".format(_data_name)) |
| | continue |
| | gt_src = os.path.join(args.gt_root, _data_name) |
| | gt_paths = sorted(glob(os.path.join(gt_src, "gt", "*"))) |
| | print("#" * 20, _data_name, "#" * 20) |
| | filename = os.path.join(args.save_dir, "{}_eval.txt".format(_data_name)) |
| | tb = pt.PrettyTable() |
| | tb.vertical_char = "&" |
| | if config.task == "DIS5K": |
| | tb.field_names = [ |
| | "Dataset", |
| | "Method", |
| | "maxFm", |
| | "wFmeasure", |
| | "MAE", |
| | "Smeasure", |
| | "meanEm", |
| | "HCE", |
| | "maxEm", |
| | "meanFm", |
| | "adpEm", |
| | "adpFm", |
| | "mBA", |
| | "maxBIoU", |
| | "meanBIoU", |
| | ] |
| | elif config.task == "COD": |
| | tb.field_names = [ |
| | "Dataset", |
| | "Method", |
| | "Smeasure", |
| | "wFmeasure", |
| | "meanFm", |
| | "meanEm", |
| | "maxEm", |
| | "MAE", |
| | "maxFm", |
| | "adpEm", |
| | "adpFm", |
| | "HCE", |
| | "mBA", |
| | "maxBIoU", |
| | "meanBIoU", |
| | ] |
| | elif config.task == "HRSOD": |
| | tb.field_names = [ |
| | "Dataset", |
| | "Method", |
| | "Smeasure", |
| | "maxFm", |
| | "meanEm", |
| | "MAE", |
| | "maxEm", |
| | "meanFm", |
| | "wFmeasure", |
| | "adpEm", |
| | "adpFm", |
| | "HCE", |
| | "mBA", |
| | "maxBIoU", |
| | "meanBIoU", |
| | ] |
| | elif config.task == "General": |
| | tb.field_names = [ |
| | "Dataset", |
| | "Method", |
| | "maxFm", |
| | "wFmeasure", |
| | "MAE", |
| | "Smeasure", |
| | "meanEm", |
| | "HCE", |
| | "maxEm", |
| | "meanFm", |
| | "adpEm", |
| | "adpFm", |
| | "mBA", |
| | "maxBIoU", |
| | "meanBIoU", |
| | ] |
| | elif config.task == "General-2K": |
| | tb.field_names = [ |
| | "Dataset", |
| | "Method", |
| | "maxFm", |
| | "wFmeasure", |
| | "MAE", |
| | "Smeasure", |
| | "meanEm", |
| | "HCE", |
| | "maxEm", |
| | "meanFm", |
| | "adpEm", |
| | "adpFm", |
| | "mBA", |
| | "maxBIoU", |
| | "meanBIoU", |
| | ] |
| | elif config.task == "Matting": |
| | tb.field_names = [ |
| | "Dataset", |
| | "Method", |
| | "Smeasure", |
| | "maxFm", |
| | "meanEm", |
| | "MSE", |
| | "maxEm", |
| | "meanFm", |
| | "wFmeasure", |
| | "adpEm", |
| | "adpFm", |
| | "HCE", |
| | "mBA", |
| | "maxBIoU", |
| | "meanBIoU", |
| | ] |
| | else: |
| | tb.field_names = [ |
| | "Dataset", |
| | "Method", |
| | "Smeasure", |
| | "MAE", |
| | "maxEm", |
| | "meanEm", |
| | "maxFm", |
| | "meanFm", |
| | "wFmeasure", |
| | "adpEm", |
| | "adpFm", |
| | "HCE", |
| | "mBA", |
| | "maxBIoU", |
| | "meanBIoU", |
| | ] |
| | for _model_name in args.model_lst[:]: |
| | print("\t", "Evaluating model: {}...".format(_model_name)) |
| | pred_paths = [ |
| | p.replace( |
| | args.gt_root, os.path.join(args.pred_root, _model_name) |
| | ).replace("/gt/", "/") |
| | for p in gt_paths |
| | ] |
| | |
| | em, sm, fm, mae, mse, wfm, hce, mba, biou = evaluator( |
| | gt_paths=gt_paths, |
| | pred_paths=pred_paths, |
| | metrics=args.metrics.split("+"), |
| | verbose=config.verbose_eval, |
| | ) |
| | if config.task == "DIS5K": |
| | scores = [ |
| | fm["curve"].max().round(3), |
| | wfm.round(3), |
| | mae.round(3), |
| | sm.round(3), |
| | em["curve"].mean().round(3), |
| | int(hce.round()), |
| | em["curve"].max().round(3), |
| | fm["curve"].mean().round(3), |
| | em["adp"].round(3), |
| | fm["adp"].round(3), |
| | mba.round(3), |
| | biou["curve"].max().round(3), |
| | biou["curve"].mean().round(3), |
| | ] |
| | elif config.task == "COD": |
| | scores = [ |
| | sm.round(3), |
| | wfm.round(3), |
| | fm["curve"].mean().round(3), |
| | em["curve"].mean().round(3), |
| | em["curve"].max().round(3), |
| | mae.round(3), |
| | fm["curve"].max().round(3), |
| | em["adp"].round(3), |
| | fm["adp"].round(3), |
| | int(hce.round()), |
| | mba.round(3), |
| | biou["curve"].max().round(3), |
| | biou["curve"].mean().round(3), |
| | ] |
| | elif config.task == "HRSOD": |
| | scores = [ |
| | sm.round(3), |
| | fm["curve"].max().round(3), |
| | em["curve"].mean().round(3), |
| | mae.round(3), |
| | em["curve"].max().round(3), |
| | fm["curve"].mean().round(3), |
| | wfm.round(3), |
| | em["adp"].round(3), |
| | fm["adp"].round(3), |
| | int(hce.round()), |
| | mba.round(3), |
| | biou["curve"].max().round(3), |
| | biou["curve"].mean().round(3), |
| | ] |
| | elif config.task == "General": |
| | scores = [ |
| | fm["curve"].max().round(3), |
| | wfm.round(3), |
| | mae.round(3), |
| | sm.round(3), |
| | em["curve"].mean().round(3), |
| | int(hce.round()), |
| | em["curve"].max().round(3), |
| | fm["curve"].mean().round(3), |
| | em["adp"].round(3), |
| | fm["adp"].round(3), |
| | mba.round(3), |
| | biou["curve"].max().round(3), |
| | biou["curve"].mean().round(3), |
| | ] |
| | elif config.task == "General-2K": |
| | scores = [ |
| | fm["curve"].max().round(3), |
| | wfm.round(3), |
| | mae.round(3), |
| | sm.round(3), |
| | em["curve"].mean().round(3), |
| | int(hce.round()), |
| | em["curve"].max().round(3), |
| | fm["curve"].mean().round(3), |
| | em["adp"].round(3), |
| | fm["adp"].round(3), |
| | mba.round(3), |
| | biou["curve"].max().round(3), |
| | biou["curve"].mean().round(3), |
| | ] |
| | elif config.task == "Matting": |
| | scores = [ |
| | sm.round(3), |
| | fm["curve"].max().round(3), |
| | em["curve"].mean().round(3), |
| | mse.round(5), |
| | em["curve"].max().round(3), |
| | fm["curve"].mean().round(3), |
| | wfm.round(3), |
| | em["adp"].round(3), |
| | fm["adp"].round(3), |
| | int(hce.round()), |
| | mba.round(3), |
| | biou["curve"].max().round(3), |
| | biou["curve"].mean().round(3), |
| | ] |
| | else: |
| | scores = [ |
| | sm.round(3), |
| | mae.round(3), |
| | em["curve"].max().round(3), |
| | em["curve"].mean().round(3), |
| | fm["curve"].max().round(3), |
| | fm["curve"].mean().round(3), |
| | wfm.round(3), |
| | em["adp"].round(3), |
| | fm["adp"].round(3), |
| | int(hce.round()), |
| | mba.round(3), |
| | biou["curve"].max().round(3), |
| | biou["curve"].mean().round(3), |
| | ] |
| |
|
| | for idx_score, score in enumerate(scores): |
| | scores[idx_score] = ( |
| | "." + format(score, ".3f").split(".")[-1] |
| | if score <= 1 |
| | else format(score, "<4") |
| | ) |
| | records = [_data_name, _model_name] + scores |
| | tb.add_row(records) |
| | |
| | with open(filename, "w+") as file_to_write: |
| | file_to_write.write(str(tb) + "\n") |
| | print(tb) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | |
| | parser = argparse.ArgumentParser() |
| | parser.add_argument( |
| | "--gt_root", |
| | type=str, |
| | help="ground-truth root", |
| | default=os.path.join(config.data_root_dir, config.task), |
| | ) |
| | parser.add_argument( |
| | "--pred_root", type=str, help="prediction root", default="./e_preds" |
| | ) |
| | parser.add_argument( |
| | "--data_lst", |
| | type=str, |
| | help="test dataset", |
| | default=config.testsets.replace(",", "+"), |
| | ) |
| | parser.add_argument( |
| | "--save_dir", type=str, help="candidate competitors", default="e_results" |
| | ) |
| | parser.add_argument( |
| | "--check_integrity", |
| | type=bool, |
| | help="whether to check the file integrity", |
| | default=False, |
| | ) |
| | parser.add_argument( |
| | "--metrics", |
| | type=str, |
| | help="candidate competitors", |
| | default="+".join( |
| | ["S", "MAE", "E", "F", "WF", "MBA", "BIoU", "MSE", "HCE"][ |
| | : 100 if "DIS5K" in config.task else -1 |
| | ] |
| | ), |
| | ) |
| | args = parser.parse_args() |
| | args.metrics = "+".join( |
| | ["S", "MAE", "E", "F", "WF", "MBA", "BIoU", "MSE", "HCE"][ |
| | : ( |
| | 100 |
| | if sum(["DIS-" in _data for _data in args.data_lst.split("+")]) |
| | else -1 |
| | ) |
| | ] |
| | ) |
| |
|
| | os.makedirs(args.save_dir, exist_ok=True) |
| | try: |
| | args.model_lst = [ |
| | m |
| | for m in sorted( |
| | os.listdir(args.pred_root), |
| | key=lambda x: int(x.split("epoch_")[-1]), |
| | reverse=True, |
| | ) |
| | if int(m.split("epoch_")[-1]) % 1 == 0 |
| | ] |
| | except: |
| | args.model_lst = [m for m in sorted(os.listdir(args.pred_root))] |
| |
|
| | |
| | if args.check_integrity: |
| | for _data_name in args.data_lst.split("+"): |
| | for _model_name in args.model_lst: |
| | gt_pth = os.path.join(args.gt_root, _data_name) |
| | pred_pth = os.path.join(args.pred_root, _model_name, _data_name) |
| | if not sorted(os.listdir(gt_pth)) == sorted(os.listdir(pred_pth)): |
| | print( |
| | len(sorted(os.listdir(gt_pth))), |
| | len(sorted(os.listdir(pred_pth))), |
| | ) |
| | print( |
| | "The {} Dataset of {} Model is not matching to the ground-truth".format( |
| | _data_name, _model_name |
| | ) |
| | ) |
| | else: |
| | print(">>> skip check the integrity of each candidates") |
| |
|
| | |
| | do_eval(args) |
| |
|