XavierJiezou commited on
Commit
c650234
·
verified ·
1 Parent(s): 28c6db0

Create vis_l8_组合.py

Browse files
Files changed (1) hide show
  1. visualization/code/vis_l8_组合.py +119 -0
visualization/code/vis_l8_组合.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from glob import glob
2
+ from mmeval import MeanIoU
3
+ from PIL import Image
4
+ import numpy as np
5
+ from typing import List
6
+ from vegseg.datasets import L8BIOMEDataset
7
+ from matplotlib import pyplot as plt
8
+ import os
9
+
10
+ def give_color_to_mask(
11
+ mask: Image.Image | np.ndarray, palette: List[int]
12
+ ) -> Image.Image:
13
+ """
14
+ Args:
15
+ mask: mask to color, numpy array or PIL Image.
16
+ palette: palette of dataset.
17
+ return:
18
+ mask: mask with color.
19
+ """
20
+ if isinstance(mask, np.ndarray):
21
+ mask = Image.fromarray(mask)
22
+ mask = mask.convert("P")
23
+ mask.putpalette(palette)
24
+ return mask
25
+
26
+ def get_iou(pred: np.ndarray, gt: np.ndarray, num_classes=2):
27
+ pred = pred[np.newaxis]
28
+ gt = gt[np.newaxis]
29
+ miou = MeanIoU(num_classes=num_classes)
30
+ result = miou(pred, gt)
31
+ return result["mIoU"] * 100
32
+
33
+
34
+ def get_palette() -> List[int]:
35
+ """
36
+ get palette of dataset.
37
+ return:
38
+ palette: list of palette.
39
+ """
40
+ palette = []
41
+ palette_list = L8BIOMEDataset.METAINFO["palette"]
42
+ for palette_item in palette_list:
43
+ palette.extend(palette_item)
44
+ return palette
45
+
46
+ def main():
47
+ ktda = glob("data/vis/ktda/*.png")
48
+
49
+ all_images = [
50
+ "cdnetv1",
51
+ "cdnetv2",
52
+ "hrcloudnet",
53
+ "input",
54
+ "kappamask",
55
+ "ktda",
56
+ "label",
57
+ "mcdnet",
58
+ "scnn",
59
+ "unetmobv2",
60
+ ]
61
+ model_order = [
62
+ "ktda",
63
+ "cdnetv1",
64
+ "cdnetv2",
65
+ "hrcloudnet",
66
+ "kappamask",
67
+ "mcdnet",
68
+ "scnn",
69
+ "unetmobv2",
70
+ ]
71
+ palette = get_palette()
72
+ for ktda_path in ktda:
73
+ images_paths = [
74
+ ktda_path.replace("ktda", filename) for filename in all_images
75
+ ]
76
+ model_name_mask = {}
77
+ model_iou = {}
78
+ label_path = ktda_path.replace("ktda", "label")
79
+ for image_path in images_paths:
80
+ model_name = image_path.split("/")[-2]
81
+ if model_name in ["input", "label"]:
82
+ continue
83
+ model_name_mask[model_name] = np.array(Image.open(image_path))
84
+ model_iou[model_name] = get_iou(
85
+ model_name_mask[model_name], np.array(Image.open(label_path)),num_classes=4
86
+ )
87
+ result_iou_sorted = sorted(model_iou.items(), key=lambda x: x[1], reverse=True)
88
+ if result_iou_sorted[0][0] != "ktda":
89
+ continue
90
+ input_path = ktda_path.replace("ktda", "input")
91
+
92
+ plt.figure(figsize=(32, 8))
93
+ plt.subplots_adjust(wspace=0.01)
94
+ plt.subplot(1, 10, 1)
95
+ plt.imshow(Image.open(input_path))
96
+ plt.axis("off")
97
+
98
+ plt.subplot(1, 10, 2)
99
+ plt.imshow(give_color_to_mask(Image.open(label_path), palette=palette))
100
+ plt.axis("off")
101
+
102
+ for i, model_name in enumerate(model_order):
103
+ plt.subplot(1, 10, i + 3)
104
+ plt.imshow(give_color_to_mask(model_name_mask[model_name], palette))
105
+ plt.axis("off")
106
+ base_name = os.path.basename(ktda_path).split(".")[0]
107
+ diff_iou = result_iou_sorted[0][1] - result_iou_sorted[1][1]
108
+ plt.savefig(
109
+ f"l8_vis/{diff_iou:.2f}_{base_name}.svg",
110
+ dpi=300,
111
+ bbox_inches="tight",
112
+ pad_inches=0,
113
+ )
114
+ plt.close()
115
+
116
+
117
+
118
+ if __name__ == "__main__":
119
+ main()