| | from glob import glob |
| | import argparse |
| | import os |
| | from typing import Tuple, List |
| | import numpy as np |
| | from mmeval import MeanIoU |
| | from PIL import Image |
| | from matplotlib import pyplot as plt |
| | from mmseg.apis import MMSegInferencer |
| | from vegseg.datasets import GrassDataset |
| | from vegseg import models |
| |
|
| |
|
def get_iou(pred: np.ndarray, gt: np.ndarray, num_classes=2):
    """Score a single prediction against its ground truth with mmeval MeanIoU.

    Args:
        pred: predicted label map.
        gt: ground-truth label map (presumably the same shape as ``pred`` — TODO confirm).
        num_classes: number of classes passed to ``MeanIoU``.
    Returns:
        The ``"mIoU"`` entry of the metric result multiplied by 100
        (presumably a percentage if mmeval reports mIoU in [0, 1]).
    """
    # MeanIoU consumes a batch, so wrap each map in a leading batch axis of size 1.
    batched_pred, batched_gt = pred[np.newaxis], gt[np.newaxis]
    metric = MeanIoU(num_classes=num_classes)
    scores = metric(batched_pred, batched_gt)
    return scores["mIoU"] * 100
| |
|
| |
|
| | def get_args() -> Tuple[str, str, int]: |
| | """ |
| | get args |
| | return: |
| | --device: device to use. |
| | --dataset_path: dataset path. |
| | --output_path: output path for saving. |
| | """ |
| | parser = argparse.ArgumentParser() |
| | parser.add_argument("--device", type=str, default="cuda:4") |
| | parser.add_argument("--dataset_path", type=str, default="data/grass") |
| | args = parser.parse_args() |
| | return args.device, args.dataset_path |
| |
|
| |
|
def give_color_to_mask(
    mask: Image.Image | np.ndarray, palette: List[int]
) -> Image.Image:
    """Render a label mask as a palettised ("P" mode) PIL image.

    Args:
        mask: label mask, either a numpy array or a PIL Image.
        palette: flat list of palette values handed to ``Image.putpalette``.
    Returns:
        The mask converted to "P" mode with ``palette`` attached.
    """
    source = Image.fromarray(mask) if isinstance(mask, np.ndarray) else mask
    colored = source.convert("P")
    colored.putpalette(palette)
    return colored
| |
|
| |
|
def get_image_and_mask_paths(
    dataset_path: str, num: int
) -> Tuple[List[str], List[str]]:
    """Collect paired image / annotation paths from the dataset.

    Images are globbed at ``<dataset_path>/img_dir/*/*.tif``; the matching
    mask for ``img_dir/<sub>/<name>.tif`` is ``ann_dir/<sub>/<name>.png``.

    Args:
        dataset_path: dataset root directory.
        num: maximum number of images to return; -1 returns all of them.
    Returns:
        image_paths: sorted list of image paths.
        mask_paths: mask path for each entry of ``image_paths``, same order.
    """
    img_root = os.path.join(dataset_path, "img_dir")
    ann_root = os.path.join(dataset_path, "ann_dir")
    # glob() order is arbitrary; sort so the subset chosen by `num` and the
    # processing order are reproducible across runs and filesystems.
    image_paths = sorted(glob(os.path.join(img_root, "*", "*.tif")))
    if num != -1:
        image_paths = image_paths[:num]
    # Swap only the img_dir root and the file extension. A plain str.replace
    # would also rewrite "tif"/"img_dir" wherever else they occur in the path
    # (e.g. "b_motif.tif" -> "b_mopng.png").
    mask_paths = [
        os.path.join(
            ann_root, os.path.splitext(os.path.relpath(path, img_root))[0] + ".png"
        )
        for path in image_paths
    ]
    return image_paths, mask_paths
| |
|
| |
|
def get_palette() -> List[int]:
    """Flatten ``GrassDataset.METAINFO["palette"]`` into one list.

    Each per-class colour entry is unpacked in order, producing the flat
    integer sequence expected by ``Image.putpalette``.

    Returns:
        Flat list of palette values.
    """
    return [
        channel
        for color in GrassDataset.METAINFO["palette"]
        for channel in color
    ]
| |
|
| |
|
def init_all_models(models_paths: List[str], device: str) -> "dict[str, MMSegInferencer]":
    """Build one MMSegInferencer per model work directory.

    Each directory must contain the model config (``*.py``) and at least one
    ``best_mIoU_iter_*.pth`` checkpoint.

    Args:
        models_paths: model work directories; each directory's basename
            becomes the model's key in the result.
        device: device passed to every inferencer.
    Returns:
        dict mapping model name (directory basename) to its inferencer.
    Raises:
        FileNotFoundError: if a directory has no config or no best checkpoint.
    """
    import re

    def _iter_number(path: str) -> int:
        # Extract N from "best_mIoU_iter_N.pth"; -1 if it is not a plain number.
        match = re.search(r"iter_(\d+)", os.path.basename(path))
        return int(match.group(1)) if match else -1

    # Named `inferencers` (not `models`) so the module-level `vegseg.models`
    # import is not shadowed.
    inferencers = {}
    for model_path in models_paths:
        config_paths = sorted(glob(os.path.join(model_path, "*.py")))
        if not config_paths:
            raise FileNotFoundError(f"No config (*.py) found in {model_path!r}")
        weight_paths = glob(os.path.join(model_path, "best_mIoU_iter_*.pth"))
        if not weight_paths:
            raise FileNotFoundError(
                f"No best_mIoU_iter_*.pth checkpoint found in {model_path!r}"
            )
        # If several "best" checkpoints coexist, take the latest iteration
        # rather than whichever one glob happens to list first.
        weight_path = max(weight_paths, key=_iter_number)
        inferencer = MMSegInferencer(
            config_paths[0],
            weight_path,
            device=device,
            classes=GrassDataset.METAINFO["classes"],
            palette=GrassDataset.METAINFO["palette"],
        )
        # normpath drops trailing separators and (on Windows) converts "/" so
        # basename yields the directory name on every platform.
        model_name = os.path.basename(os.path.normpath(model_path))
        inferencers[model_name] = inferencer
    return inferencers
| |
|
| |
|
def _predict_all(
    inferencers: dict, image_path: str, gt: np.ndarray, num_classes: int
) -> Tuple[dict, dict]:
    """Run every inferencer on one image and score it against ``gt``.

    Returns:
        predictions: model name -> predicted label map cast to uint8.
        ious: model name -> ``get_iou`` score of that prediction.
    """
    predictions = {}
    ious = {}
    for model_name, inferencer in inferencers.items():
        pred = inferencer(image_path)["predictions"].astype(np.uint8)
        predictions[model_name] = pred
        ious[model_name] = get_iou(pred, gt, num_classes=num_classes)
    return predictions, ious


def _save_comparison_figure(
    image_path: str,
    mask_image,
    predictions: dict,
    model_order: List[str],
    palette: List[int],
    out_path: str,
) -> None:
    """Save input image, ground truth and each model's prediction side by side."""
    num_panels = len(model_order) + 2  # input image + ground truth + one per model
    fig = plt.figure(figsize=(32, 8))
    try:
        plt.subplots_adjust(wspace=0.01)

        plt.subplot(1, num_panels, 1)
        # imshow converts the PIL image to an array immediately, so the file
        # can be closed right after.
        with Image.open(image_path) as image:
            plt.imshow(image)
        plt.axis("off")

        plt.subplot(1, num_panels, 2)
        plt.imshow(give_color_to_mask(mask_image, palette=palette))
        plt.axis("off")

        for panel, model_name in enumerate(model_order, start=3):
            plt.subplot(1, num_panels, panel)
            plt.imshow(give_color_to_mask(predictions[model_name], palette))
            plt.axis("off")

        fig.savefig(out_path, dpi=300, bbox_inches="tight", pad_inches=0)
    finally:
        # Without closing, every figure stays registered in pyplot and memory
        # grows with the number of saved images.
        plt.close(fig)


def main():
    """Save comparison figures for images where "experiment_p" wins.

    Every model in ``model_order`` (loaded from ``work_dirs/<name>``) is run on
    every dataset image. When "experiment_p" has a strictly higher IoU than all
    other models, an SVG with the input, the ground truth and every
    prediction is written to ``vis_results/<margin>_<image name>.svg``. Here
    ``margin`` is experiment_p's IoU lead over the best other model.
    """
    target_model = "experiment_p"
    # Hard-coded as in the original script; presumably equals
    # len(GrassDataset.METAINFO["classes"]) — verify before changing.
    num_classes = 5
    output_dir = "vis_results"

    device, dataset_path = get_args()
    image_paths, mask_paths = get_image_and_mask_paths(dataset_path, -1)
    palette = get_palette()
    # Panel order in the saved figures; each name is also the model's
    # directory under work_dirs/.
    model_order = [
        "experiment_p",
        "fcn_r50",
        "pspnet_r101",
        "deeplabv3plus_r101",
        "unet-s5-d16_deeplabv3",
        "segformer_mit-b5",
        "mask2former_swin_b",
        "dinov2_upernet",
    ]
    # Named `inferencers` so the module-level `vegseg.models` import is not shadowed.
    inferencers = init_all_models(
        [os.path.join("work_dirs", name) for name in model_order], device
    )

    os.makedirs(output_dir, exist_ok=True)
    for image_path, mask_path in zip(image_paths, mask_paths):
        # copy() loads the pixels, so the file handle is released on exit.
        with Image.open(mask_path) as opened:
            mask_image = opened.copy()
        gt = np.array(mask_image)  # computed once, shared by all models
        predictions, ious = _predict_all(inferencers, image_path, gt, num_classes)

        # Keep only images where the target model strictly beats every other
        # model (ties are skipped, as the original ordering-based check did).
        best_other = max(iou for name, iou in ious.items() if name != target_model)
        if ious[target_model] <= best_other:
            continue
        margin = ious[target_model] - best_other

        # splitext keeps dots inside the file stem ("a.b.tif" -> "a.b").
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        out_path = os.path.join(output_dir, f"{margin:.2f}_{base_name}.svg")
        _save_comparison_figure(
            image_path, mask_image, predictions, model_order, palette, out_path
        )
| |
|
| |
|
if __name__ == "__main__":
    # Only run the visualisation when executed as a script, not on import.
    main()
| |
|