| | import argparse |
| | import os |
| | from typing import List |
| |
|
| | import metrics |
| | import numpy as np |
| | import torch |
| | import torchvision.transforms.functional as TF |
| | from PIL import Image |
| | from sklearn.metrics import balanced_accuracy_score |
| | from torchvision.ops import masks_to_boxes |
| | from tqdm import tqdm |
| |
|
| |
|
def compute_size_boundry_centroid(binary_mask):
    """Return the bounding-box size, box centre and box edges of a 2-D mask.

    Args:
        binary_mask (np.ndarray): 2-D array of shape ``(H, W)``; non-zero
            entries are treated as foreground.

    Returns:
        tuple: ``((width, height), (centroid_x, centroid_y),
        (left, right, top, bottom))``.

        - Edges are inclusive pixel indices.
        - The "centroid" is the integer centre of the bounding box, not the
          centre of mass.
        - For an empty mask the size is ``(0, 0)``, and the centroid and all
          edges collapse onto the image centre ``(W // 2, H // 2)``.
    """
    H, W = binary_mask.shape
    # np.nonzero uses the same truthiness as np.any. A non-boolean mask
    # (e.g. 0/255) therefore cannot be classified as non-empty and then
    # produce no coordinates, which would make np.min/np.max raise.
    ys, xs = np.nonzero(binary_mask)
    if xs.size == 0:
        # Centroid is (x, y), matching the non-empty branch below and the
        # (left, right, top, bottom) edge order.
        return (0, 0), (W // 2, H // 2), (W // 2, W // 2, H // 2, H // 2)

    left_boundary, right_boundary = int(xs.min()), int(xs.max())
    top_boundary, bottom_boundary = int(ys.min()), int(ys.max())

    centroid_x = (left_boundary + right_boundary) // 2
    centroid_y = (top_boundary + bottom_boundary) // 2

    # +1 because the edges are inclusive indices.
    width = right_boundary - left_boundary + 1
    height = bottom_boundary - top_boundary + 1

    return (
        (width, height),
        (centroid_x, centroid_y),
        (left_boundary, right_boundary, top_boundary, bottom_boundary),
    )
| |
|
| |
|
def crop_mask(mask1, mask2):
    """Crop two same-shaped masks around their own centroids to a common size.

    Both masks are zero-padded first. Each mask is then cropped around its own
    bounding-box centre (see ``compute_size_boundry_centroid``). The crop
    extent per axis is the larger of the two masks' bounding-box extents on
    that axis, rounded up to an odd number of pixels.

    Args:
        mask1 (np.ndarray): 2-D mask of shape ``(H, W)``.
        mask2 (np.ndarray): 2-D mask with the same shape as ``mask1``.

    Returns:
        tuple: ``((size_1, size_2), (centroid_1, centroid_2),
        (compact_mask_1, compact_mask_2))``.

        - Sizes are ``(width, height)`` bounding-box extents.
        - Centroids are ``(x, y)`` in *padded* coordinates, offset by
          ``(W, H)``. Differences between the two centroids are unaffected
          by the offset.
        - Both compact masks have the same shape.

    Raises:
        AssertionError: if the two masks have different shapes.
    """
    assert (
        mask1.shape == mask2.shape
    ), f"Annotation({mask1.shape}) and segmentation:{mask2.shape} dimensions do not match."

    # Pad each axis by that axis's own full extent. A crop reaches at most
    # dim // 2 + 1 pixels past a centroid that lies inside the original image,
    # so this keeps every slice below in bounds. Padding the column axis by
    # the row count is not enough for wide masks: slice starts go negative
    # (numpy wraps them around) and ends run off the array.
    rows, cols = mask1.shape
    pad_width = ((rows, rows), (cols, cols))
    mask1 = np.pad(mask1, pad_width, mode="constant", constant_values=False)
    mask2 = np.pad(mask2, pad_width, mode="constant", constant_values=False)

    size_1, centroid_1, _ = compute_size_boundry_centroid(mask1)
    size_2, centroid_2, _ = compute_size_boundry_centroid(mask2)

    width, height = max(size_1[0], size_2[0]), max(size_1[1], size_2[1])

    def _crop_around(mask, centroid):
        # Window of 2*(height//2)+1 rows by 2*(width//2)+1 columns centred on (x, y).
        cx, cy = centroid
        return mask[
            cy - height // 2 : cy + height // 2 + 1,
            cx - width // 2 : cx + width // 2 + 1,
        ]

    compact_mask_1 = _crop_around(mask1, centroid_1)
    compact_mask_2 = _crop_around(mask2, centroid_2)

    return (size_1, size_2), (centroid_1, centroid_2), (compact_mask_1, compact_mask_2)
| |
|
| |
|
def crop_resize_using_mask(img: torch.Tensor, mask: torch.Tensor, target_size: List[int]):
    """Crop ``img`` to the bounding box of ``mask`` and resize the crop.

    Args:
        img (torch.Tensor): shape: [C,H,W].
        mask (torch.Tensor): shape: [H,W]; non-zero entries are foreground.
        target_size (List[int]): output spatial size, forwarded to ``TF.resize``.

    Returns:
        torch.Tensor: the cropped and resized image. If ``mask`` has no
        foreground, a tensor of ones with shape ``[C] + target_size`` is
        returned instead. It uses torch's default dtype and is placed on
        ``img``'s device.
    """
    if mask.sum() == 0:
        return torch.ones([img.shape[0]] + list(target_size), device=img.device)

    # masks_to_boxes returns (x1, y1, x2, y2) where x2/y2 are the *inclusive*
    # maximum foreground indices, so the box spans x2 - x1 + 1 columns and
    # y2 - y1 + 1 rows. A single-row or single-column mask still gives a
    # non-empty crop.
    x1, y1, x2, y2 = (int(v) for v in masks_to_boxes(mask.unsqueeze(0)).squeeze(0))
    img = TF.crop(img, y1, x1, y2 - y1 + 1, x2 - x1 + 1)
    return TF.resize(img, target_size)
| |
|
| |
|
def load_mask(mask_path: str, mode="L", size=None, threshold=127):
    """Load a mask image from disk as a boolean numpy array.

    The image goes through three steps:

    1. It is converted to PIL ``mode``.
    2. It is resized with nearest-neighbour resampling, so no interpolated
       intermediate values appear at mask edges.
    3. It is binarised as ``pixel > threshold``.

    Args:
        mask_path (str): path to the mask image.
        mode (str): PIL mode the image is converted to before thresholding.
            The default ``"L"`` is single-channel, giving a 2-D result.
        size (tuple | None): ``(width, height)`` passed to ``Image.resize``.
            ``None`` (the default) uses the module-level ``IMG_SIZE``.
        threshold (int): pixels strictly greater than this value are foreground.

    Returns:
        np.ndarray | None: array of dtype ``bool``. Returns ``None`` if the
        file does not exist or any error occurred while loading it; the
        reason is printed.
    """
    try:
        if not os.path.isfile(mask_path):
            print(f"file not existed for image: {mask_path}")
            return None

        target_size = IMG_SIZE if size is None else size
        # The context manager releases the file handle even if convert/resize raises.
        with Image.open(mask_path) as raw_mask:
            resized = raw_mask.convert(mode).resize(target_size, resample=Image.NEAREST)
        return np.array(resized) > threshold
    except Exception as e:
        # Deliberate best-effort: report the problem and return None so the
        # caller can skip or penalise this file instead of aborting the run.
        print(f"Exception while loading image: {mask_path}: {e}")
        return None
| |
|
| |
|
# Target (width, height) that load_mask resizes every mask to. Its diagonal
# also normalises centroid distances in evaluate.
IMG_SIZE: tuple = (256, 256)

# Sub-directory names under the results directory holding ground-truth
# masks and predicted masks, respectively.
DIR_GT: str = "ground-truths"
DIR_PRED: str = "predictions"
| |
|
| |
|
def _mean_or_nan(values):
    """Return the mean of ``values`` as a float, or NaN when ``values`` is empty."""
    # np.mean([]) also yields NaN, but it emits a "Mean of empty slice" RuntimeWarning.
    return float(np.mean(values)) if values else float("nan")


def _registered_scores(mask_gt, mask_pred):
    """Return ``(riou, rboundary, location_error)`` for one non-empty GT mask.

    A missing or unreadable prediction (``None``) gets the worst scores
    ``(0.0, 0.0, 1.0)``. Otherwise, ``crop_mask`` crops each mask around its
    own centroid to a common size, so the IoU and boundary score compare
    shape rather than position. Position is measured separately as the
    centroid distance divided by the ``IMG_SIZE`` diagonal.
    """
    if mask_pred is None:
        return 0.0, 0.0, 1.0

    # NOTE(review): a prediction file that exists but is all background is
    # cropped around the image centre, so its location_error is the GT
    # centroid's distance to the centre rather than the 1.0 used for a missing
    # file. Confirm this asymmetry is intended.
    _, (centroid_gt, centroid_pred), (gt_compact_mask, pred_compact_mask) = crop_mask(
        mask_gt, mask_pred
    )
    centroid_distance = np.sqrt(
        (centroid_gt[0] - centroid_pred[0]) ** 2 + (centroid_gt[1] - centroid_pred[1]) ** 2
    )

    riou = metrics.db_eval_iou(gt_compact_mask, pred_compact_mask)
    rboundary = metrics.db_eval_boundary(gt_compact_mask, pred_compact_mask)
    location_error = centroid_distance / np.sqrt(IMG_SIZE[0] ** 2 + IMG_SIZE[1] ** 2)
    return riou, rboundary, location_error


def evaluate(results_dir):
    """Score predicted masks against ground-truth masks and print the metrics.

    Reads every ``.png`` in ``<results_dir>/ground-truths`` (``DIR_GT``) and
    the same-named file in ``<results_dir>/predictions`` (``DIR_PRED``).
    Images whose ground-truth mask cannot be loaded are skipped.

    - Mask existence: for every remaining image, whether any foreground
      exists in GT and in the prediction. Scored with balanced accuracy.
    - Registered IoU, boundary score and normalised location error: only for
      images whose GT mask is non-empty (see ``_registered_scores``).

    Any metric with no contributing images is reported as NaN.

    Args:
        results_dir (str): directory containing the two sub-directories above.

    Returns:
        dict: the printed metrics, keyed by ``"mask_existence_balanced_accuracy"``,
        ``"riou"``, ``"rboundary"`` and ``"location_error"``.
    """
    dir_gt = os.path.join(results_dir, DIR_GT)
    dir_pred = os.path.join(results_dir, DIR_PRED)

    # Sorted so repeated runs visit images in a reproducible order.
    img_names = sorted(os.listdir(dir_gt))
    print(f"number of images: {len(img_names)}")

    # Per-image "any foreground?" flags, for all images with a readable GT mask.
    mask_existence_gt_list = []
    mask_existence_pred_list = []
    # Registered metrics, only for images whose GT mask is non-empty.
    riou_score_list = []
    rboundary_score_list = []
    location_error_list = []

    for img_name in tqdm(img_names):
        if not img_name.endswith(".png"):
            continue

        mask_gt = load_mask(os.path.join(dir_gt, img_name))
        if mask_gt is None:
            continue

        mask_pred = load_mask(os.path.join(dir_pred, img_name))

        mask_existence_gt = np.any(mask_gt)
        # A missing or unreadable prediction counts as "no mask predicted".
        mask_existence_pred = mask_pred is not None and np.any(mask_pred)
        mask_existence_gt_list.append(mask_existence_gt)
        mask_existence_pred_list.append(mask_existence_pred)

        if not mask_existence_gt:
            continue

        riou, rboundary, location_error = _registered_scores(mask_gt, mask_pred)
        riou_score_list.append(riou)
        rboundary_score_list.append(rboundary)
        location_error_list.append(location_error)

    if mask_existence_gt_list:
        balanced_mask_existence_acc = float(
            balanced_accuracy_score(mask_existence_gt_list, mask_existence_pred_list)
        )
    else:
        # Avoid calling balanced_accuracy_score on an empty sample set.
        balanced_mask_existence_acc = float("nan")
    riou_score = _mean_or_nan(riou_score_list)
    rboundary_score = _mean_or_nan(rboundary_score_list)
    location_error = _mean_or_nan(location_error_list)

    print(f"mask_existence balanced accuracy={balanced_mask_existence_acc}")
    print(f"Registered IOU score={riou_score}")
    print(f"Registered boundary score={rboundary_score}")
    print(f"Location Error={location_error}")

    return {
        "mask_existence_balanced_accuracy": balanced_mask_existence_acc,
        "riou": riou_score,
        "rboundary": rboundary_score,
        "location_error": location_error,
    }
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: evaluate a single results directory.
    parser = argparse.ArgumentParser()
    # required=True: without it, a missing flag reaches evaluate(None), which
    # then fails inside os.path.join with an unhelpful TypeError instead of a
    # usage message.
    parser.add_argument(
        "--results_dir", type=str, required=True, help="The directory of the results"
    )

    args = parser.parse_args()

    evaluate(args.results_dir)
| |
|