TuuSiwei commited on
Commit
f9a5fe7
·
1 Parent(s): 7e2f3c0

Add application file

Browse files
app.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ import gradio as gr
3
+ import torch
4
+ from utils.tools_gradio import fast_process
5
+ from utils.tools import format_results, box_prompt, point_prompt, text_prompt
6
+ from PIL import ImageDraw
7
+ import numpy as np
8
+
9
+ # Load the pre-trained model
10
+ model = YOLO('./weights/flashsam.pt')
11
+
12
+ device = torch.device(
13
+ "cuda"
14
+ if torch.cuda.is_available()
15
+ else "mps"
16
+ if torch.backends.mps.is_available()
17
+ else "cpu"
18
+ )
19
+
20
+ # Description
21
+ title = "<center><strong><font size='8'>🏃 FlashSAM 🤗</font></strong></center>"
22
+
23
+ news = """ # 📖 News
24
+ 🔥 2025/11/16: Release the first demo.
25
+ """
26
+
27
+ description_e = """This is a demo about FlashSAM.
28
+
29
+ 🎯 Upload an Image, segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
30
+ """
31
+
32
+ description_p = """ # 🎯 Instructions for points mode
33
+ This is a demo about FlashSAM.
34
+
35
+ 1. Upload an image or choose an example.
36
+
37
+ 2. Choose the point label ('Add mask' means a positive point. 'Remove' Area means a negative point that is not segmented).
38
+
39
+ 3. Add points one by one on the image.
40
+
41
+ 4. Click the 'Segment with points prompt' button to get the segmentation results.
42
+
43
+ **5. If you get Error, click the 'Clear points' button and try again may help.**
44
+
45
+ """
46
+
47
+ examples = [["examples/sa_8776.jpg"], ["examples/sa_414.jpg"], ["examples/sa_1309.jpg"], ["examples/sa_11025.jpg"],
48
+ ["examples/sa_561.jpg"], ["examples/sa_192.jpg"], ["examples/sa_10039.jpg"], ["examples/sa_862.jpg"]]
49
+
50
+ default_example = examples[0]
51
+
52
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
53
+
54
+
55
+ def segment_everything(
56
+ input,
57
+ input_size=1024,
58
+ iou_threshold=0.7,
59
+ conf_threshold=0.25,
60
+ better_quality=False,
61
+ withContours=True,
62
+ use_retina=True,
63
+ text="",
64
+ wider=False,
65
+ mask_random_color=True,
66
+ ):
67
+ input_size = int(input_size) # 确保 imgsz 是整数
68
+ # Thanks for the suggestion by hysts in HuggingFace.
69
+ w, h = input.size
70
+ scale = input_size / max(w, h)
71
+ new_w = int(w * scale)
72
+ new_h = int(h * scale)
73
+ input = input.resize((new_w, new_h))
74
+
75
+ results = model(input,
76
+ device=device,
77
+ retina_masks=True,
78
+ iou=iou_threshold,
79
+ conf=conf_threshold,
80
+ imgsz=input_size,)
81
+
82
+ if len(text) > 0:
83
+ results = format_results(results[0], 0)
84
+ annotations, _ = text_prompt(results, text, input, device=device, wider=wider)
85
+ annotations = np.array([annotations])
86
+ else:
87
+ annotations = results[0].masks.data
88
+
89
+ fig = fast_process(annotations=annotations,
90
+ image=input,
91
+ device=device,
92
+ scale=(1024 // input_size),
93
+ better_quality=better_quality,
94
+ mask_random_color=mask_random_color,
95
+ bbox=None,
96
+ use_retina=use_retina,
97
+ withContours=withContours,)
98
+ return fig
99
+
100
+
101
+ def segment_with_points(
102
+ input,
103
+ input_size=1024,
104
+ iou_threshold=0.7,
105
+ conf_threshold=0.25,
106
+ better_quality=False,
107
+ withContours=True,
108
+ use_retina=True,
109
+ mask_random_color=True,
110
+ ):
111
+ global global_points
112
+ global global_point_label
113
+
114
+ input_size = int(input_size) # 确保 imgsz 是整数
115
+ # Thanks for the suggestion by hysts in HuggingFace.
116
+ w, h = input.size
117
+ scale = input_size / max(w, h)
118
+ new_w = int(w * scale)
119
+ new_h = int(h * scale)
120
+ input = input.resize((new_w, new_h))
121
+
122
+ scaled_points = [[int(x * scale) for x in point] for point in global_points]
123
+
124
+ results = model(input,
125
+ device=device,
126
+ retina_masks=True,
127
+ iou=iou_threshold,
128
+ conf=conf_threshold,
129
+ imgsz=input_size,)
130
+
131
+ results = format_results(results[0], 0)
132
+ annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
133
+ annotations = np.array([annotations])
134
+
135
+ fig = fast_process(annotations=annotations,
136
+ image=input,
137
+ device=device,
138
+ scale=(1024 // input_size),
139
+ better_quality=better_quality,
140
+ mask_random_color=mask_random_color,
141
+ bbox=None,
142
+ use_retina=use_retina,
143
+ withContours=withContours,)
144
+
145
+ return fig
146
+
147
+
148
+ def get_points_with_draw(image, label, evt: gr.SelectData):
149
+ global global_points
150
+ global global_point_label
151
+
152
+ x, y = evt.index[0], evt.index[1]
153
+ point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
154
+ global_points.append([x, y])
155
+ global_point_label.append(1 if label == 'Add Mask' else 0)
156
+
157
+ print(x, y, label == 'Add Mask')
158
+
159
+ # 创建一个可以在图像上绘图的对象
160
+ draw = ImageDraw.Draw(image)
161
+ draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
162
+ return image
163
+
164
+
165
+ cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
166
+ cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil')
167
+ cond_img_t = gr.Image(label="Input with text", value="examples/dogs.jpg", type='pil')
168
+
169
+ segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
170
+ segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
171
+ segm_img_t = gr.Image(label="Segmented Image with text", interactive=False, type='pil')
172
+
173
+ global_points = []
174
+ global_point_label = []
175
+
176
+ input_size_slider = gr.components.Slider(minimum=512,
177
+ maximum=1024,
178
+ value=1024,
179
+ step=64,
180
+ label='Input_size',
181
+ info='Our model was trained on a size of 1024')
182
+
183
+ with gr.Blocks(css=css, title='FlashSAM') as demo:
184
+ with gr.Row():
185
+ with gr.Column(scale=1):
186
+ # Title
187
+ gr.Markdown(title)
188
+
189
+ with gr.Column(scale=1):
190
+ # News
191
+ gr.Markdown(news)
192
+
193
+ everything_tab = gr.Tab("Everything mode")
194
+ points_tab = gr.Tab("Points mode")
195
+ text_tab = gr.Tab("Text mode")
196
+
197
+ with everything_tab:
198
+ # Images
199
+ with gr.Row(variant="panel"):
200
+ with gr.Column(scale=1):
201
+ cond_img_e.render()
202
+
203
+ with gr.Column(scale=1):
204
+ segm_img_e.render()
205
+
206
+ # Submit & Clear
207
+ with gr.Row():
208
+ with gr.Column():
209
+ input_size_slider.render()
210
+
211
+ with gr.Row():
212
+ contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
213
+
214
+ with gr.Column():
215
+ segment_btn_e = gr.Button("Segment Everything", variant='primary')
216
+ clear_btn_e = gr.Button("Clear", variant="secondary")
217
+
218
+ gr.Markdown("Try some of the examples below ⬇️")
219
+ gr.Examples(examples=examples,
220
+ inputs=[cond_img_e],
221
+ outputs=segm_img_e,
222
+ fn=segment_everything,
223
+ cache_examples=True,
224
+ examples_per_page=4)
225
+
226
+ with gr.Column():
227
+ with gr.Accordion("Advanced options", open=False):
228
+ iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
229
+ conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
230
+ with gr.Row():
231
+ mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
232
+ with gr.Column():
233
+ retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
234
+
235
+ # Description
236
+ gr.Markdown(description_e)
237
+
238
+ segment_btn_e.click(segment_everything,
239
+ inputs=[
240
+ cond_img_e,
241
+ input_size_slider,
242
+ iou_threshold,
243
+ conf_threshold,
244
+ mor_check,
245
+ contour_check,
246
+ retina_check,
247
+ ],
248
+ outputs=segm_img_e)
249
+
250
+ with points_tab:
251
+ # Images
252
+ with gr.Row(variant="panel"):
253
+ with gr.Column(scale=1):
254
+ cond_img_p.render()
255
+
256
+ with gr.Column(scale=1):
257
+ segm_img_p.render()
258
+
259
+ # Submit & Clear
260
+ with gr.Row():
261
+ with gr.Column():
262
+ with gr.Row():
263
+ add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")
264
+
265
+ with gr.Column():
266
+ segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
267
+ clear_btn_p = gr.Button("Clear points", variant='secondary')
268
+
269
+ gr.Markdown("Try some of the examples below ⬇️")
270
+ gr.Examples(examples=examples,
271
+ inputs=[cond_img_p],
272
+ # outputs=segm_img_p,
273
+ # fn=segment_with_points,
274
+ # cache_examples=True,
275
+ examples_per_page=4)
276
+
277
+ with gr.Column():
278
+ # Description
279
+ gr.Markdown(description_p)
280
+
281
+ cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
282
+
283
+ segment_btn_p.click(segment_with_points,
284
+ inputs=[cond_img_p],
285
+ outputs=[segm_img_p])
286
+
287
+ with text_tab:
288
+ # Images
289
+ with gr.Row(variant="panel"):
290
+ with gr.Column(scale=1):
291
+ cond_img_t.render()
292
+
293
+ with gr.Column(scale=1):
294
+ segm_img_t.render()
295
+
296
+ # Submit & Clear
297
+ with gr.Row():
298
+ with gr.Column():
299
+ input_size_slider_t = gr.components.Slider(minimum=512,
300
+ maximum=1024,
301
+ value=1024,
302
+ step=64,
303
+ label='Input_size',
304
+ info='Our model was trained on a size of 1024')
305
+ with gr.Row():
306
+ with gr.Column():
307
+ contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
308
+ text_box = gr.Textbox(label="text prompt", value="a yellow dog")
309
+
310
+ with gr.Column():
311
+ segment_btn_t = gr.Button("Segment with text", variant='primary')
312
+ clear_btn_t = gr.Button("Clear", variant="secondary")
313
+
314
+ gr.Markdown("Try some of the examples below ⬇️")
315
+ gr.Examples(examples=[["examples/dogs.jpg"], ["examples/fruits.jpg"], ["examples/flowers.jpg"]],
316
+ inputs=[cond_img_t],
317
+ # outputs=segm_img_e,
318
+ # fn=segment_everything,
319
+ # cache_examples=True,
320
+ examples_per_page=4)
321
+
322
+ with gr.Column():
323
+ with gr.Accordion("Advanced options", open=False):
324
+ iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
325
+ conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
326
+ with gr.Row():
327
+ mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
328
+ retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
329
+ wider_check = gr.Checkbox(value=False, label='wider', info='wider result')
330
+
331
+ # Description
332
+ gr.Markdown(description_e)
333
+
334
+ segment_btn_t.click(segment_everything,
335
+ inputs=[
336
+ cond_img_t,
337
+ input_size_slider_t,
338
+ iou_threshold,
339
+ conf_threshold,
340
+ mor_check,
341
+ contour_check,
342
+ retina_check,
343
+ text_box,
344
+ wider_check,
345
+ ],
346
+ outputs=segm_img_t)
347
+
348
+ def clear():
349
+ global global_points
350
+ global global_point_label
351
+ global_points = []
352
+ global_point_label = []
353
+ return None, None
354
+
355
+ def clear_text():
356
+ return None, None, None
357
+
358
+ everything_tab.select(clear, outputs=[cond_img_e, segm_img_e]) # reset when everything tab is selected
359
+ points_tab.select(clear, outputs=[cond_img_e, segm_img_e]) # reset when points tab is selected
360
+ text_tab.select(clear, outputs=[cond_img_e, segm_img_e]) # reset when text tab is selected
361
+ cond_img_p.clear(clear, outputs=[cond_img_e, segm_img_e]) # reset when input image is cleared
362
+ cond_img_p.input(clear, outputs=[cond_img_e, segm_img_e]) # reset when input image is changed
363
+ clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
364
+ clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
365
+ clear_btn_t.click(clear_text, outputs=[cond_img_p, segm_img_p, text_box])
366
+
367
+ demo.queue()
368
+ demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ aliyun-python-sdk-core==2.16.0
3
+ aliyun-python-sdk-kms==2.16.5
4
+ annotated-doc==0.0.4
5
+ annotated-types==0.7.0
6
+ anyio==4.11.0
7
+ brotli==1.2.0
8
+ certifi==2025.11.12
9
+ cffi==1.17.1
10
+ charset-normalizer==3.4.3
11
+ click==8.3.1
12
+ clip==1.0
13
+ contourpy==1.3.2
14
+ crcmod==1.7
15
+ cryptography==45.0.7
16
+ cycler==0.12.1
17
+ exceptiongroup==1.3.0
18
+ fastapi==0.121.2
19
+ faster-coco-eval==1.6.7
20
+ ffmpy==1.0.0
21
+ filelock==3.20.0
22
+ fonttools==4.59.2
23
+ fsspec==2025.10.0
24
+ ftfy==6.3.1
25
+ gradio==5.49.1
26
+ gradio-client==1.13.3
27
+ groovy==0.1.2
28
+ h11==0.16.0
29
+ hf-xet==1.2.0
30
+ httpcore==1.0.9
31
+ httpx==0.28.1
32
+ huggingface-hub==1.1.4
33
+ idna==3.11
34
+ jinja2==3.1.6
35
+ jmespath==0.10.0
36
+ kiwisolver==1.4.9
37
+ markdown-it-py==4.0.0
38
+ markupsafe==3.0.3
39
+ matplotlib==3.7.0
40
+ mdurl==0.1.2
41
+ mpmath==1.3.0
42
+ narwhals==2.2.0
43
+ networkx==3.4.2
44
+ numpy==1.26.4
45
+ opencv-python==4.11.0.86
46
+ openxlab==0.1.2
47
+ orjson==3.11.4
48
+ oss2==2.17.0
49
+ packaging==25.0
50
+ pandas==2.3.3
51
+ pillow==11.3.0
52
+ plotly==6.3.0
53
+ polars==1.32.3
54
+ psutil==7.0.0
55
+ py-cpuinfo==9.0.0
56
+ pycocotools==2.0.7
57
+ pycparser==2.22
58
+ pycryptodome==3.23.0
59
+ pydantic==2.11.10
60
+ pydantic-core==2.33.2
61
+ pydub==0.25.1
62
+ pygments==2.19.2
63
+ pyparsing==3.2.3
64
+ python-dateutil==2.9.0.post0
65
+ python-multipart==0.0.20
66
+ pytz==2025.2
67
+ pyyaml==6.0.3
68
+ regex==2025.11.3
69
+ requests==2.28.2
70
+ rich==14.2.0
71
+ ruff==0.14.5
72
+ safehttpx==0.1.7
73
+ scipy==1.15.3
74
+ semantic-version==2.10.0
75
+ setuptools==60.2.0
76
+ shellingham==1.5.4
77
+ six==1.17.0
78
+ sniffio==1.3.1
79
+ starlette==0.49.3
80
+ sympy==1.14.0
81
+ tomlkit==0.13.3
82
+ torch==2.8.0
83
+ torchvision==0.23.0
84
+ tqdm==4.67.1
85
+ typer==0.20.0
86
+ typer-slim==0.20.0
87
+ typing-extensions==4.15.0
88
+ typing-inspection==0.4.2
89
+ tzdata==2025.2
90
+ ultralytics==8.3.191
91
+ ultralytics-thop==2.0.16
92
+ urllib3==1.26.20
93
+ uvicorn==0.38.0
94
+ wcwidth==0.2.14
95
+ websockets==15.0.1
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (135 Bytes). View file
 
utils/__pycache__/tools.cpython-310.pyc ADDED
Binary file (12.3 kB). View file
 
utils/__pycache__/tools_gradio.cpython-310.pyc ADDED
Binary file (4.07 kB). View file
 
utils/tools.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import cv2
5
+ import torch
6
+ import os
7
+ import sys
8
+ import clip
9
+
10
+
11
+ def convert_box_xywh_to_xyxy(box):
12
+ if len(box) == 4:
13
+ return [box[0], box[1], box[0] + box[2], box[1] + box[3]]
14
+ else:
15
+ result = []
16
+ for b in box:
17
+ b = convert_box_xywh_to_xyxy(b)
18
+ result.append(b)
19
+ return result
20
+
21
+
22
+ def segment_image(image, bbox):
23
+ image_array = np.array(image)
24
+ segmented_image_array = np.zeros_like(image_array)
25
+ x1, y1, x2, y2 = bbox
26
+ segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
27
+ segmented_image = Image.fromarray(segmented_image_array)
28
+ black_image = Image.new("RGB", image.size, (255, 255, 255))
29
+ # transparency_mask = np.zeros_like((), dtype=np.uint8)
30
+ transparency_mask = np.zeros(
31
+ (image_array.shape[0], image_array.shape[1]), dtype=np.uint8
32
+ )
33
+ transparency_mask[y1:y2, x1:x2] = 255
34
+ transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
35
+ black_image.paste(segmented_image, mask=transparency_mask_image)
36
+ return black_image
37
+
38
+
39
+ def format_results(result, filter=0):
40
+ annotations = []
41
+ n = len(result.masks.data)
42
+ for i in range(n):
43
+ annotation = {}
44
+ mask = result.masks.data[i] == 1.0
45
+
46
+ if torch.sum(mask) < filter:
47
+ continue
48
+ annotation["id"] = i
49
+ annotation["segmentation"] = mask.cpu().numpy()
50
+ annotation["bbox"] = result.boxes.data[i]
51
+ annotation["score"] = result.boxes.conf[i]
52
+ annotation["area"] = annotation["segmentation"].sum()
53
+ annotations.append(annotation)
54
+ return annotations
55
+
56
+
57
+ def filter_masks(annotations): # filter the overlap mask
58
+ annotations.sort(key=lambda x: x["area"], reverse=True)
59
+ to_remove = set()
60
+ for i in range(0, len(annotations)):
61
+ a = annotations[i]
62
+ for j in range(i + 1, len(annotations)):
63
+ b = annotations[j]
64
+ if i != j and j not in to_remove:
65
+ # check if
66
+ if b["area"] < a["area"]:
67
+ if (a["segmentation"] & b["segmentation"]).sum() / b[
68
+ "segmentation"
69
+ ].sum() > 0.8:
70
+ to_remove.add(j)
71
+
72
+ return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
73
+
74
+
75
+ def get_bbox_from_mask(mask):
76
+ mask = mask.astype(np.uint8)
77
+ contours, hierarchy = cv2.findContours(
78
+ mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
79
+ )
80
+ x1, y1, w, h = cv2.boundingRect(contours[0])
81
+ x2, y2 = x1 + w, y1 + h
82
+ if len(contours) > 1:
83
+ for b in contours:
84
+ x_t, y_t, w_t, h_t = cv2.boundingRect(b)
85
+ # 将多个bbox合并成一个
86
+ x1 = min(x1, x_t)
87
+ y1 = min(y1, y_t)
88
+ x2 = max(x2, x_t + w_t)
89
+ y2 = max(y2, y_t + h_t)
90
+ h = y2 - y1
91
+ w = x2 - x1
92
+ return [x1, y1, x2, y2]
93
+
94
+
95
+ def fast_process(
96
+ annotations, args, mask_random_color, bbox=None, points=None, edges=False
97
+ ):
98
+ if isinstance(annotations[0], dict):
99
+ annotations = [annotation["segmentation"] for annotation in annotations]
100
+ result_name = os.path.basename(args.img_path)
101
+ image = cv2.imread(args.img_path)
102
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
103
+ original_h = image.shape[0]
104
+ original_w = image.shape[1]
105
+ if sys.platform == "darwin":
106
+ plt.switch_backend("TkAgg")
107
+ plt.figure(figsize=(original_w/100, original_h/100))
108
+ # Add subplot with no margin.
109
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
110
+ plt.margins(0, 0)
111
+ plt.gca().xaxis.set_major_locator(plt.NullLocator())
112
+ plt.gca().yaxis.set_major_locator(plt.NullLocator())
113
+ plt.imshow(image)
114
+ if args.better_quality == True:
115
+ if isinstance(annotations[0], torch.Tensor):
116
+ annotations = np.array(annotations.cpu())
117
+ for i, mask in enumerate(annotations):
118
+ mask = cv2.morphologyEx(
119
+ mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
120
+ )
121
+ annotations[i] = cv2.morphologyEx(
122
+ mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
123
+ )
124
+ if args.device == "cpu":
125
+ annotations = np.array(annotations)
126
+ fast_show_mask(
127
+ annotations,
128
+ plt.gca(),
129
+ random_color=mask_random_color,
130
+ bbox=bbox,
131
+ points=points,
132
+ point_label=args.point_label,
133
+ retinamask=args.retina,
134
+ target_height=original_h,
135
+ target_width=original_w,
136
+ )
137
+ else:
138
+ if isinstance(annotations[0], np.ndarray):
139
+ annotations = torch.from_numpy(annotations)
140
+ fast_show_mask_gpu(
141
+ annotations,
142
+ plt.gca(),
143
+ random_color=args.randomcolor,
144
+ bbox=bbox,
145
+ points=points,
146
+ point_label=args.point_label,
147
+ retinamask=args.retina,
148
+ target_height=original_h,
149
+ target_width=original_w,
150
+ )
151
+ if isinstance(annotations, torch.Tensor):
152
+ annotations = annotations.cpu().numpy()
153
+ if args.withContours == True:
154
+ contour_all = []
155
+ temp = np.zeros((original_h, original_w, 1))
156
+ for i, mask in enumerate(annotations):
157
+ if type(mask) == dict:
158
+ mask = mask["segmentation"]
159
+ annotation = mask.astype(np.uint8)
160
+ if args.retina == False:
161
+ annotation = cv2.resize(
162
+ annotation,
163
+ (original_w, original_h),
164
+ interpolation=cv2.INTER_NEAREST,
165
+ )
166
+ contours, hierarchy = cv2.findContours(
167
+ annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
168
+ )
169
+ for contour in contours:
170
+ contour_all.append(contour)
171
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
172
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
173
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
174
+ plt.imshow(contour_mask)
175
+
176
+ save_path = args.output
177
+ if not os.path.exists(save_path):
178
+ os.makedirs(save_path)
179
+ plt.axis("off")
180
+ fig = plt.gcf()
181
+ plt.draw()
182
+
183
+ try:
184
+ buf = fig.canvas.tostring_rgb()
185
+ except AttributeError:
186
+ fig.canvas.draw()
187
+ buf = fig.canvas.tostring_rgb()
188
+
189
+ cols, rows = fig.canvas.get_width_height()
190
+ img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
191
+ cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
192
+
193
+
194
+ # CPU post process
195
+ def fast_show_mask(
196
+ annotation,
197
+ ax,
198
+ random_color=False,
199
+ bbox=None,
200
+ points=None,
201
+ point_label=None,
202
+ retinamask=True,
203
+ target_height=960,
204
+ target_width=960,
205
+ ):
206
+ msak_sum = annotation.shape[0]
207
+ height = annotation.shape[1]
208
+ weight = annotation.shape[2]
209
+ # 将annotation 按照面积 排序
210
+ areas = np.sum(annotation, axis=(1, 2))
211
+ sorted_indices = np.argsort(areas)
212
+ annotation = annotation[sorted_indices]
213
+
214
+ index = (annotation != 0).argmax(axis=0)
215
+ if random_color == True:
216
+ color = np.random.random((msak_sum, 1, 1, 3))
217
+ else:
218
+ color = np.ones((msak_sum, 1, 1, 3)) * np.array(
219
+ [30 / 255, 144 / 255, 255 / 255]
220
+ )
221
+ transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
222
+ visual = np.concatenate([color, transparency], axis=-1)
223
+ mask_image = np.expand_dims(annotation, -1) * visual
224
+
225
+ show = np.zeros((height, weight, 4))
226
+ h_indices, w_indices = np.meshgrid(
227
+ np.arange(height), np.arange(weight), indexing="ij"
228
+ )
229
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
230
+ # 使用向量化索引更新show的值
231
+ show[h_indices, w_indices, :] = mask_image[indices]
232
+ if bbox is not None:
233
+ x1, y1, x2, y2 = bbox
234
+ ax.add_patch(
235
+ plt.Rectangle(
236
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
237
+ )
238
+ )
239
+ # draw point
240
+ if points is not None:
241
+ plt.scatter(
242
+ [point[0] for i, point in enumerate(points) if point_label[i] == 1],
243
+ [point[1] for i, point in enumerate(points) if point_label[i] == 1],
244
+ s=20,
245
+ c="y",
246
+ )
247
+ plt.scatter(
248
+ [point[0] for i, point in enumerate(points) if point_label[i] == 0],
249
+ [point[1] for i, point in enumerate(points) if point_label[i] == 0],
250
+ s=20,
251
+ c="m",
252
+ )
253
+
254
+ if retinamask == False:
255
+ show = cv2.resize(
256
+ show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
257
+ )
258
+ ax.imshow(show)
259
+
260
+
261
+ def fast_show_mask_gpu(
262
+ annotation,
263
+ ax,
264
+ random_color=False,
265
+ bbox=None,
266
+ points=None,
267
+ point_label=None,
268
+ retinamask=True,
269
+ target_height=960,
270
+ target_width=960,
271
+ ):
272
+ msak_sum = annotation.shape[0]
273
+ height = annotation.shape[1]
274
+ weight = annotation.shape[2]
275
+ areas = torch.sum(annotation, dim=(1, 2))
276
+ sorted_indices = torch.argsort(areas, descending=False)
277
+ annotation = annotation[sorted_indices]
278
+ # 找每个位置第一个非零值下标
279
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
280
+ if random_color == True:
281
+ color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
282
+ else:
283
+ color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor(
284
+ [30 / 255, 144 / 255, 255 / 255]
285
+ ).to(annotation.device)
286
+ transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
287
+ visual = torch.cat([color, transparency], dim=-1)
288
+ mask_image = torch.unsqueeze(annotation, -1) * visual
289
+ # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
290
+ show = torch.zeros((height, weight, 4)).to(annotation.device)
291
+ h_indices, w_indices = torch.meshgrid(
292
+ torch.arange(height), torch.arange(weight), indexing="ij"
293
+ )
294
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
295
+ # 使用向量化索引更新show的值
296
+ show[h_indices, w_indices, :] = mask_image[indices]
297
+ show_cpu = show.cpu().numpy()
298
+ if bbox is not None:
299
+ x1, y1, x2, y2 = bbox
300
+ ax.add_patch(
301
+ plt.Rectangle(
302
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
303
+ )
304
+ )
305
+ # draw point
306
+ if points is not None:
307
+ plt.scatter(
308
+ [point[0] for i, point in enumerate(points) if point_label[i] == 1],
309
+ [point[1] for i, point in enumerate(points) if point_label[i] == 1],
310
+ s=20,
311
+ c="y",
312
+ )
313
+ plt.scatter(
314
+ [point[0] for i, point in enumerate(points) if point_label[i] == 0],
315
+ [point[1] for i, point in enumerate(points) if point_label[i] == 0],
316
+ s=20,
317
+ c="m",
318
+ )
319
+ if retinamask == False:
320
+ show_cpu = cv2.resize(
321
+ show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
322
+ )
323
+ ax.imshow(show_cpu)
324
+
325
+
326
+ # clip
327
+ @torch.no_grad()
328
+ def retriev(
329
+ model, preprocess, elements: [Image.Image], search_text: str, device
330
+ ):
331
+ preprocessed_images = [preprocess(image).to(device) for image in elements]
332
+ tokenized_text = clip.tokenize([search_text]).to(device)
333
+ stacked_images = torch.stack(preprocessed_images)
334
+ image_features = model.encode_image(stacked_images)
335
+ text_features = model.encode_text(tokenized_text)
336
+ image_features /= image_features.norm(dim=-1, keepdim=True)
337
+ text_features /= text_features.norm(dim=-1, keepdim=True)
338
+ probs = 100.0 * image_features @ text_features.T
339
+ return probs[:, 0].softmax(dim=0)
340
+
341
+
342
+ def crop_image(annotations, image_like):
343
+ if isinstance(image_like, str):
344
+ image = Image.open(image_like)
345
+ else:
346
+ image = image_like
347
+ ori_w, ori_h = image.size
348
+ mask_h, mask_w = annotations[0]["segmentation"].shape
349
+ if ori_w != mask_w or ori_h != mask_h:
350
+ image = image.resize((mask_w, mask_h))
351
+ cropped_boxes = []
352
+ cropped_images = []
353
+ not_crop = []
354
+ origin_id = []
355
+ for _, mask in enumerate(annotations):
356
+ if np.sum(mask["segmentation"]) <= 100:
357
+ continue
358
+ origin_id.append(_)
359
+ bbox = get_bbox_from_mask(mask["segmentation"]) # mask 的 bbox
360
+ cropped_boxes.append(segment_image(image, bbox)) # 保存裁剪的图片
361
+ # cropped_boxes.append(segment_image(image,mask["segmentation"]))
362
+ cropped_images.append(bbox) # 保存裁剪的图片的bbox
363
+ return cropped_boxes, cropped_images, not_crop, origin_id, annotations
364
+
365
+
366
+ def box_prompt(masks, bbox, target_height, target_width):
367
+ h = masks.shape[1]
368
+ w = masks.shape[2]
369
+ if h != target_height or w != target_width:
370
+ bbox = [
371
+ int(bbox[0] * w / target_width),
372
+ int(bbox[1] * h / target_height),
373
+ int(bbox[2] * w / target_width),
374
+ int(bbox[3] * h / target_height),
375
+ ]
376
+ bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
377
+ bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
378
+ bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
379
+ bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
380
+
381
+ # IoUs = torch.zeros(len(masks), dtype=torch.float32)
382
+ bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
383
+
384
+ masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
385
+ orig_masks_area = torch.sum(masks, dim=(1, 2))
386
+
387
+ union = bbox_area + orig_masks_area - masks_area
388
+ IoUs = masks_area / union
389
+ max_iou_index = torch.argmax(IoUs)
390
+
391
+ return masks[max_iou_index].cpu().numpy(), max_iou_index
392
+
393
+
394
+ def point_prompt(masks, points, point_label, target_height, target_width): # numpy 处理
395
+ h = masks[0]["segmentation"].shape[0]
396
+ w = masks[0]["segmentation"].shape[1]
397
+ if h != target_height or w != target_width:
398
+ points = [
399
+ [int(point[0] * w / target_width), int(point[1] * h / target_height)]
400
+ for point in points
401
+ ]
402
+ onemask = np.zeros((h, w))
403
+ masks = sorted(masks, key=lambda x: x['area'], reverse=True)
404
+ for i, annotation in enumerate(masks):
405
+ if type(annotation) == dict:
406
+ mask = annotation['segmentation']
407
+ else:
408
+ mask = annotation
409
+ for i, point in enumerate(points):
410
+ if mask[point[1], point[0]] == 1 and point_label[i] == 1:
411
+ onemask[mask] = 1
412
+ if mask[point[1], point[0]] == 1 and point_label[i] == 0:
413
+ onemask[mask] = 0
414
+ onemask = onemask >= 1
415
+ return onemask, 0
416
+
417
+
418
+ def text_prompt(annotations, text, img_path, device, wider=False, threshold=0.9):
419
+ cropped_boxes, cropped_images, not_crop, origin_id, annotations_ = crop_image(
420
+ annotations, img_path
421
+ )
422
+ clip_model, preprocess = clip.load("./weights/CLIP_ViT_B_32.pt", device=device)
423
+ scores = retriev(
424
+ clip_model, preprocess, cropped_boxes, text, device=device
425
+ )
426
+ max_idx = scores.argsort()
427
+ max_idx = max_idx[-1]
428
+ max_idx = origin_id[int(max_idx)]
429
+
430
+ # find the biggest mask which contains the mask with max score
431
+ if wider:
432
+ mask0 = annotations_[max_idx]["segmentation"]
433
+ area0 = np.sum(mask0)
434
+ areas = [(i, np.sum(mask["segmentation"])) for i, mask in enumerate(annotations_) if i in origin_id]
435
+ areas = sorted(areas, key=lambda area: area[1], reverse=True)
436
+ indices = [area[0] for area in areas]
437
+ for index in indices:
438
+ if index == max_idx or np.sum(annotations_[index]["segmentation"] & mask0) / area0 > threshold:
439
+ max_idx = index
440
+ break
441
+
442
+ return annotations_[max_idx]["segmentation"], max_idx
utils/tools_gradio.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import cv2
5
+ import torch
6
+
7
+
8
+ def fast_process(
9
+ annotations,
10
+ image,
11
+ device,
12
+ scale,
13
+ better_quality=False,
14
+ mask_random_color=True,
15
+ bbox=None,
16
+ use_retina=True,
17
+ withContours=True,
18
+ ):
19
+ if isinstance(annotations[0], dict):
20
+ annotations = [annotation['segmentation'] for annotation in annotations]
21
+
22
+ original_h = image.height
23
+ original_w = image.width
24
+ if better_quality:
25
+ if isinstance(annotations[0], torch.Tensor):
26
+ annotations = np.array(annotations.cpu())
27
+ for i, mask in enumerate(annotations):
28
+ mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
29
+ annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
30
+ if device == 'cpu':
31
+ annotations = np.array(annotations)
32
+ inner_mask = fast_show_mask(
33
+ annotations,
34
+ plt.gca(),
35
+ random_color=mask_random_color,
36
+ bbox=bbox,
37
+ retinamask=use_retina,
38
+ target_height=original_h,
39
+ target_width=original_w,
40
+ )
41
+ else:
42
+ if isinstance(annotations[0], np.ndarray):
43
+ annotations = torch.from_numpy(annotations)
44
+ inner_mask = fast_show_mask_gpu(
45
+ annotations,
46
+ plt.gca(),
47
+ random_color=mask_random_color,
48
+ bbox=bbox,
49
+ retinamask=use_retina,
50
+ target_height=original_h,
51
+ target_width=original_w,
52
+ )
53
+ if isinstance(annotations, torch.Tensor):
54
+ annotations = annotations.cpu().numpy()
55
+
56
+ if withContours:
57
+ contour_all = []
58
+ temp = np.zeros((original_h, original_w, 1))
59
+ for i, mask in enumerate(annotations):
60
+ if type(mask) == dict:
61
+ mask = mask['segmentation']
62
+ annotation = mask.astype(np.uint8)
63
+ if use_retina == False:
64
+ annotation = cv2.resize(
65
+ annotation,
66
+ (original_w, original_h),
67
+ interpolation=cv2.INTER_NEAREST,
68
+ )
69
+ contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
70
+ for contour in contours:
71
+ contour_all.append(contour)
72
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
73
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
74
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
75
+
76
+ image = image.convert('RGBA')
77
+ overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
78
+ image.paste(overlay_inner, (0, 0), overlay_inner)
79
+
80
+ if withContours:
81
+ overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
82
+ image.paste(overlay_contour, (0, 0), overlay_contour)
83
+
84
+ return image
85
+
86
+
87
+ # CPU post process
88
+ def fast_show_mask(
89
+ annotation,
90
+ ax,
91
+ random_color=False,
92
+ bbox=None,
93
+ retinamask=True,
94
+ target_height=960,
95
+ target_width=960,
96
+ ):
97
+ mask_sum = annotation.shape[0]
98
+ height = annotation.shape[1]
99
+ weight = annotation.shape[2]
100
+ # 将annotation 按照面积 排序
101
+ areas = np.sum(annotation, axis=(1, 2))
102
+ sorted_indices = np.argsort(areas)[::1]
103
+ annotation = annotation[sorted_indices]
104
+
105
+ index = (annotation != 0).argmax(axis=0)
106
+ if random_color:
107
+ color = np.random.random((mask_sum, 1, 1, 3))
108
+ else:
109
+ color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
110
+ transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
111
+ visual = np.concatenate([color, transparency], axis=-1)
112
+ mask_image = np.expand_dims(annotation, -1) * visual
113
+
114
+ mask = np.zeros((height, weight, 4))
115
+
116
+ h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
117
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
118
+
119
+ mask[h_indices, w_indices, :] = mask_image[indices]
120
+ if bbox is not None:
121
+ x1, y1, x2, y2 = bbox
122
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
123
+
124
+ if not retinamask:
125
+ mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
126
+
127
+ return mask
128
+
129
+
130
+ def fast_show_mask_gpu(
131
+ annotation,
132
+ ax,
133
+ random_color=False,
134
+ bbox=None,
135
+ retinamask=True,
136
+ target_height=960,
137
+ target_width=960,
138
+ ):
139
+ device = annotation.device
140
+ mask_sum = annotation.shape[0]
141
+ height = annotation.shape[1]
142
+ weight = annotation.shape[2]
143
+ areas = torch.sum(annotation, dim=(1, 2))
144
+ sorted_indices = torch.argsort(areas, descending=False)
145
+ annotation = annotation[sorted_indices]
146
+ # 找每个位置第一个非零值下标
147
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
148
+ if random_color:
149
+ color = torch.rand((mask_sum, 1, 1, 3)).to(device)
150
+ else:
151
+ color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
152
+ [30 / 255, 144 / 255, 255 / 255]
153
+ ).to(device)
154
+ transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
155
+ visual = torch.cat([color, transparency], dim=-1)
156
+ mask_image = torch.unsqueeze(annotation, -1) * visual
157
+ # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
158
+ mask = torch.zeros((height, weight, 4)).to(device)
159
+ h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
160
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
161
+ # 使用向量化索引更新show的值
162
+ mask[h_indices, w_indices, :] = mask_image[indices]
163
+ mask_cpu = mask.cpu().numpy()
164
+ if bbox is not None:
165
+ x1, y1, x2, y2 = bbox
166
+ ax.add_patch(
167
+ plt.Rectangle(
168
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
169
+ )
170
+ )
171
+ if not retinamask:
172
+ mask_cpu = cv2.resize(
173
+ mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
174
+ )
175
+ return mask_cpu