XavierJiezou commited on
Commit
2dd4dfc
·
verified ·
1 Parent(s): c650234

Create vis_model_plus_save.py

Browse files
visualization/code/vis_model_plus_save.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from glob import glob
2
+ import argparse
3
+ import os
4
+ from typing import Tuple, List
5
+ import numpy as np
6
+ from mmeval import MeanIoU
7
+ from PIL import Image
8
+ from matplotlib import pyplot as plt
9
+ from mmseg.apis import MMSegInferencer
10
+ from vegseg.datasets import GrassDataset
11
+ from vegseg import models
12
+
13
+
14
+ def get_args() -> Tuple[str, str, int]:
15
+ """
16
+ get args
17
+ return:
18
+ --device: device to use.
19
+ --dataset_path: dataset path.
20
+ --output_path: output path for saving.
21
+ """
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--device", type=str, default="cuda:4")
24
+ parser.add_argument("--dataset_path", type=str, default="data/grass")
25
+ args = parser.parse_args()
26
+ return args.device, args.dataset_path
27
+
28
+
29
+ def give_color_to_mask(
30
+ mask: Image.Image | np.ndarray, palette: List[int]
31
+ ) -> Image.Image:
32
+ """
33
+ Args:
34
+ mask: mask to color, numpy array or PIL Image.
35
+ palette: palette of dataset.
36
+ return:
37
+ mask: mask with color.
38
+ """
39
+ if isinstance(mask, np.ndarray):
40
+ mask = Image.fromarray(mask)
41
+ mask = mask.convert("P")
42
+ mask.putpalette(palette)
43
+ return mask
44
+
45
+
46
+ def get_image_and_mask_paths(
47
+ dataset_path: str, num: int
48
+ ) -> Tuple[List[str], List[str]]:
49
+ """
50
+ get image and mask paths from dataset path.
51
+ return:
52
+ image_paths: list of image paths.
53
+ mask_paths: list of mask paths.
54
+ """
55
+ image_paths = glob(os.path.join(dataset_path, "img_dir", "*", "*.tif"))
56
+ if num != -1:
57
+ image_paths = image_paths[:num]
58
+ mask_paths = [
59
+ filename.replace("tif", "png").replace("img_dir", "ann_dir")
60
+ for filename in image_paths
61
+ ]
62
+ return image_paths, mask_paths
63
+
64
+
65
+ def get_palette() -> List[int]:
66
+ """
67
+ get palette of dataset.
68
+ return:
69
+ palette: list of palette.
70
+ """
71
+ palette = []
72
+ palette_list = GrassDataset.METAINFO["palette"]
73
+ for palette_item in palette_list:
74
+ palette.extend(palette_item)
75
+ return palette
76
+
77
+
78
+ def init_all_models(models_paths: List[str], device: str):
79
+ """
80
+ init all models
81
+ Args:
82
+ models_path (str): path to all models.
83
+ device (str): device to use.
84
+ Return:
85
+ models (dict): dict of models.
86
+ """
87
+ models = {}
88
+ for model_path in models_paths:
89
+ print(model_path)
90
+ config_path = glob(os.path.join(model_path, "*.py"))[0]
91
+ weight_path = glob(os.path.join(model_path, "best_mIoU_iter_*.pth"))[0]
92
+ inference = MMSegInferencer(
93
+ config_path,
94
+ weight_path,
95
+ device=device,
96
+ classes=GrassDataset.METAINFO["classes"],
97
+ palette=GrassDataset.METAINFO["palette"],
98
+ )
99
+ model_name = model_path.split(os.path.sep)[-1]
100
+ models[model_name] = inference
101
+ return models
102
+
103
+
104
+ def main():
105
+ device, dataset_path = get_args()
106
+ image_paths, mask_paths = get_image_and_mask_paths(dataset_path, -1)
107
+ palette = get_palette()
108
+ models_paths = [
109
+ r"work_dirs/fcn_r50",
110
+ r"work_dirs/pspnet_r101",
111
+ r"work_dirs/deeplabv3plus_r101",
112
+ r"work_dirs/unet-s5-d16_deeplabv3",
113
+ r"work_dirs/segformer_mit-b5",
114
+ r"work_dirs/mask2former_swin_b",
115
+ r"work_dirs/dinov2_upernet",
116
+ r"work_dirs/experiment_p",
117
+ ]
118
+ models = init_all_models(models_paths, device)
119
+
120
+ model_order = [
121
+ "experiment_p",
122
+ "fcn_r50",
123
+ "pspnet_r101",
124
+ "deeplabv3plus_r101",
125
+ "unet-s5-d16_deeplabv3",
126
+ "segformer_mit-b5",
127
+ "mask2former_swin_b",
128
+ "dinov2_upernet"
129
+ ]
130
+
131
+ model_mapping = {
132
+ "experiment_p":"ktda",
133
+ "fcn_r50":"fcn",
134
+ "pspnet_r101":"pspnet",
135
+ "deeplabv3plus_r101":"deeplabv3plus",
136
+ "unet-s5-d16_deeplabv3":"unet",
137
+ "segformer_mit-b5":"segformer",
138
+ "mask2former_swin_b":"mask2former",
139
+ "dinov2_upernet":"dinov2"
140
+ }
141
+
142
+ os.makedirs("vis_results", exist_ok=True)
143
+ for model_name in model_order:
144
+ os.makedirs(f"data/visualization/grass/{model_name}", exist_ok=True)
145
+ os.makedirs(f"data/visualization/grass/input", exist_ok=True)
146
+ os.makedirs(f"data/visualization/grass/label", exist_ok=True)
147
+ for image_path, mask_path in zip(image_paths, mask_paths):
148
+ filename = os.path.basename(image_path)
149
+ for model_name, inference in models.items():
150
+ predictions: np.ndarray = inference(image_path)["predictions"]
151
+ predictions = predictions.astype(np.uint8)
152
+
153
+
154
+ predictions = give_color_to_mask(predictions, palette=palette)
155
+
156
+ predictions.save(f"data/visualization/grass/{model_name}/{filename}")
157
+
158
+
159
+ Image.open(image_path).save(f"data/visualization/grass/input/{filename}")
160
+ Image.open(mask_path).save(f"data/visualization/grass/label/{filename}")
161
+
162
+
163
+ if __name__ == "__main__":
164
+ # example usage: python tools/vis_model.py --models work_dirs --device cuda:0 --dataset_path data/grass
165
+ main()