hbazai commited on
Commit
a1687ef
·
verified ·
1 Parent(s): 8b4224e

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. .gradio/certificate.pem +31 -0
  2. README.md +2 -8
  3. amg_paddle.py +308 -0
  4. promt_predict.py +128 -0
  5. text_to_sam_clip.py +241 -0
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: Segmentanything
3
- emoji: 🦀
4
- colorFrom: blue
5
- colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 5.9.1
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: segmentanything
3
+ app_file: text_to_sam_clip.py
 
 
4
  sdk: gradio
5
  sdk_version: 5.9.1
 
 
6
  ---
 
 
amg_paddle.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # This implementation refers to: https://github.com/facebookresearch/segment-anything
16
+
17
+ import os
18
+ import sys
19
+ sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))
20
+
21
+ import time
22
+ import cv2 # type: ignore
23
+ import argparse
24
+ import numpy as np # type: ignore
25
+ import paddle
26
+
27
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
28
+ from paddleseg.utils.visualize import get_pseudo_color_map, get_color_map_list
29
+
30
+ ID_PHOTO_IMAGE_DEMO = "examples/cityscapes_demo.png"
31
+ CACHE_DIR = ".temp"
32
+
33
+ model_link = {
34
+ 'vit_h':
35
+ "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams",
36
+ 'vit_l':
37
+ "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams",
38
+ 'vit_b':
39
+ "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams",
40
+ 'vit_t':
41
+ "https://paddleseg.bj.bcebos.com/dygraph/paddlesegAnything/vit_t/model.pdparam"
42
+ }
43
+
44
+ parser = argparse.ArgumentParser(description=(
45
+ "Runs automatic mask generation on an input image or directory of images, "
46
+ "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
47
+ "as well as pycocotools if saving in RLE format."))
48
+
49
+ parser.add_argument(
50
+ "--model-type",
51
+ type=str,
52
+ default="vit_l",
53
+ required=True,
54
+ help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b', 'vit_t']", )
55
+
56
+ parser.add_argument(
57
+ "--convert-to-rle",
58
+ action="store_true",
59
+ help=(
60
+ "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
61
+ "Requires pycocotools."), )
62
+
63
+ amg_settings = parser.add_argument_group("AMG Settings")
64
+
65
+ amg_settings.add_argument(
66
+ "--points-per-side",
67
+ type=int,
68
+ default=None,
69
+ help="Generate masks by sampling a grid over the image with this many points to a side.",
70
+ )
71
+
72
+ amg_settings.add_argument(
73
+ "--points-per-batch",
74
+ type=int,
75
+ default=None,
76
+ help="How many input points to process simultaneously in one batch.", )
77
+
78
+ amg_settings.add_argument(
79
+ "--pred-iou-thresh",
80
+ type=float,
81
+ default=None,
82
+ help="Exclude masks with a predicted score from the model that is lower than this threshold.",
83
+ )
84
+
85
+ amg_settings.add_argument(
86
+ "--stability-score-thresh",
87
+ type=float,
88
+ default=None,
89
+ help="Exclude masks with a stability score lower than this threshold.", )
90
+
91
+ amg_settings.add_argument(
92
+ "--stability-score-offset",
93
+ type=float,
94
+ default=None,
95
+ help="Larger values perturb the mask more when measuring stability score.",
96
+ )
97
+
98
+ amg_settings.add_argument(
99
+ "--box-nms-thresh",
100
+ type=float,
101
+ default=None,
102
+ help="The overlap threshold for excluding a duplicate mask.", )
103
+
104
+ amg_settings.add_argument(
105
+ "--crop-n-layers",
106
+ type=int,
107
+ default=None,
108
+ help=(
109
+ "If >0, mask generation is run on smaller crops of the image to generate more masks. "
110
+ "The value sets how many different scales to crop at."), )
111
+
112
+ amg_settings.add_argument(
113
+ "--crop-nms-thresh",
114
+ type=float,
115
+ default=None,
116
+ help="The overlap threshold for excluding duplicate masks across different crops.",
117
+ )
118
+
119
+ amg_settings.add_argument(
120
+ "--crop-overlap-ratio",
121
+ type=int,
122
+ default=None,
123
+ help="Larger numbers mean image crops will overlap more.", )
124
+
125
+ amg_settings.add_argument(
126
+ "--crop-n-points-downscale-factor",
127
+ type=int,
128
+ default=None,
129
+ help="The number of points-per-side in each layer of crop is reduced by this factor.",
130
+ )
131
+
132
+ amg_settings.add_argument(
133
+ "--min-mask-region-area",
134
+ type=int,
135
+ default=None,
136
+ help=(
137
+ "Disconnected mask regions or holes with area smaller than this value "
138
+ "in pixels are removed by postprocessing."), )
139
+
140
+
141
+ def get_amg_kwargs(args):
142
+ amg_kwargs = {
143
+ "points_per_side": args.points_per_side,
144
+ "points_per_batch": args.points_per_batch,
145
+ "pred_iou_thresh": args.pred_iou_thresh,
146
+ "stability_score_thresh": args.stability_score_thresh,
147
+ "stability_score_offset": args.stability_score_offset,
148
+ "box_nms_thresh": args.box_nms_thresh,
149
+ "crop_n_layers": args.crop_n_layers,
150
+ "crop_nms_thresh": args.crop_nms_thresh,
151
+ "crop_overlap_ratio": args.crop_overlap_ratio,
152
+ "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
153
+ "min_mask_region_area": args.min_mask_region_area,
154
+ }
155
+ amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
156
+ return amg_kwargs
157
+
158
+
159
+ def delete_result():
160
+ """clear old result in `.temp`"""
161
+ results = sorted(os.listdir(CACHE_DIR))
162
+ for res in results:
163
+ if int(time.time()) - int(os.path.splitext(res)[0]) > 10000:
164
+ os.remove(os.path.join(CACHE_DIR, res))
165
+
166
+
167
+ def download(img):
168
+ if not os.path.exists(CACHE_DIR):
169
+ os.makedirs(CACHE_DIR)
170
+ while True:
171
+ name = str(int(time.time()))
172
+ tmp_name = os.path.join(CACHE_DIR, name + '.jpg')
173
+ if not os.path.exists(tmp_name):
174
+ break
175
+ else:
176
+ time.sleep(1)
177
+
178
+ img.save(tmp_name, 'png')
179
+ return tmp_name
180
+
181
+
182
+ def masks2pseudomap(masks):
183
+ result = np.ones(masks[0]["segmentation"].shape, dtype=np.uint8) * 255
184
+ for i, mask_data in enumerate(masks):
185
+ result[mask_data["segmentation"] == 1] = i + 1
186
+ pred_result = result
187
+ result = get_pseudo_color_map(result)
188
+
189
+ return pred_result, result
190
+
191
+
192
+ def visualize(image, result, color_map, weight=0.6):
193
+ """
194
+ Convert predict result to color image, and save added image.
195
+
196
+ Args:
197
+ image (str): The path of origin image.
198
+ result (np.ndarray): The predict result of image.
199
+ color_map (list): The color used to save the prediction results.
200
+ save_dir (str): The directory for saving visual image. Default: None.
201
+ weight (float): The image weight of visual image, and the result weight is (1 - weight). Default: 0.6
202
+
203
+ Returns:
204
+ vis_result (np.ndarray): If `save_dir` is None, return the visualized result.
205
+ """
206
+
207
+ color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
208
+ color_map = np.array(color_map).astype("uint8")
209
+ # Use OpenCV LUT for color mapping
210
+ c1 = cv2.LUT(result, color_map[:, 0])
211
+ c2 = cv2.LUT(result, color_map[:, 1])
212
+ c3 = cv2.LUT(result, color_map[:, 2])
213
+ pseudo_img = np.dstack((c3, c2, c1))
214
+
215
+ # im = cv2.imread(image)
216
+ vis_result = cv2.addWeighted(image, weight, pseudo_img, 1 - weight, 0)
217
+ return vis_result
218
+
219
+
220
+ def gradio_display(generator):
221
+ import gradio as gr
222
+
223
+ def clear_image_all():
224
+ delete_result()
225
+ return None, None, None, None
226
+
227
+ def get_id_photo_output(img):
228
+ """
229
+ Get the special size and background photo.
230
+
231
+ Args:
232
+ img(numpy:ndarray): The image array.
233
+ size(str): The size user specified.
234
+ bg(str): The background color user specified.
235
+ download_size(str): The size for image saving.
236
+
237
+ """
238
+ predictor = generator
239
+ masks = predictor.generate(img)
240
+ pred_result, pseudo_map = masks2pseudomap(masks) # PIL Image
241
+ added_pseudo_map = visualize(
242
+ img, pred_result, color_map=get_color_map_list(256))
243
+ res_download = download(pseudo_map)
244
+
245
+ return pseudo_map, added_pseudo_map, res_download
246
+
247
+ with gr.Blocks() as demo:
248
+ gr.Markdown("""# Segment Anything (PaddleSeg) """)
249
+ with gr.Tab("InputImage"):
250
+ image_in = gr.Image(value=ID_PHOTO_IMAGE_DEMO, label="Input image")
251
+
252
+ with gr.Row():
253
+ image_clear_btn = gr.Button("Clear")
254
+ image_submit_btn = gr.Button("Submit")
255
+
256
+ with gr.Row():
257
+ img_out1 = gr.Image(
258
+ label="Output image", interactive=False).style(height=300)
259
+ img_out2 = gr.Image(
260
+ label="Output image with mask",
261
+ interactive=False).style(height=300)
262
+ downloaded_img = gr.File(label='Image download').style(height=50)
263
+
264
+ image_clear_btn.click(
265
+ fn=clear_image_all,
266
+ inputs=None,
267
+ outputs=[image_in, img_out1, img_out2, downloaded_img])
268
+
269
+ image_submit_btn.click(
270
+ fn=get_id_photo_output,
271
+ inputs=[image_in, ],
272
+ outputs=[img_out1, img_out2, downloaded_img])
273
+
274
+ gr.Markdown(
275
+ """<font color=Gray>Tips: You can try segment the default image OR upload any images you want to segment by click on the clear button first.</font>"""
276
+ )
277
+
278
+ gr.Markdown(
279
+ """<font color=Gray>This is Segment Anything build with PaddlePaddle.
280
+ We refer to the [SAM](https://github.com/facebookresearch/segment-anything) for code strucure and model architecture.
281
+ If you have any question or feature request, welcome to raise issues on [GitHub](https://github.com/PaddlePaddle/PaddleSeg/issues). </font>"""
282
+ )
283
+
284
+ gr.Button.style(1)
285
+
286
+ demo.launch(server_name="0.0.0.0", server_port=8017, share=True)
287
+
288
+
289
+ def main(args: argparse.Namespace) -> None:
290
+ print("Loading model...")
291
+
292
+ sam = sam_model_registry[args.model_type](
293
+ checkpoint=model_link[args.model_type])
294
+ if paddle.is_compiled_with_cuda():
295
+ paddle.set_device("gpu")
296
+ else:
297
+ paddle.set_device("cpu")
298
+ output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
299
+ amg_kwargs = get_amg_kwargs(args)
300
+ generator = SamAutomaticMaskGenerator(
301
+ sam, output_mode=output_mode, **amg_kwargs)
302
+
303
+ gradio_display(generator)
304
+
305
+
306
+ if __name__ == "__main__":
307
+ args = parser.parse_args()
308
+ main(args)
promt_predict.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # This implementation refers to: https://github.com/facebookresearch/segment-anything
16
+
17
+ import os
18
+ import sys
19
+ import argparse
20
+ sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))
21
+
22
+ import paddle
23
+ import cv2
24
+ import numpy as np
25
+ import matplotlib.pyplot as plt
26
+
27
+ from segment_anything.predictor import SamPredictor
28
+ from segment_anything.build_sam import sam_model_registry
29
+
30
+ model_link = {
31
+ 'vit_h':
32
+ "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams",
33
+ 'vit_l':
34
+ "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams",
35
+ 'vit_b':
36
+ "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams",
37
+ 'vit_t':
38
+ "https://paddleseg.bj.bcebos.com/dygraph/paddlesegAnything/vit_t/model.pdparam"
39
+ }
40
+
41
+
42
+ def get_args():
43
+ parser = argparse.ArgumentParser(
44
+ description='Segment image with point promp or box')
45
+ # Parameters
46
+ parser.add_argument(
47
+ '--input_path', type=str, required=True, help='The directory of image.')
48
+ parser.add_argument(
49
+ "--model-type",
50
+ type=str,
51
+ default="vit_l",
52
+ required=True,
53
+ help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b', 'vit_t']",
54
+ )
55
+ parser.add_argument(
56
+ '--point_prompt',
57
+ type=int,
58
+ nargs='+',
59
+ default=None,
60
+ help='point prompt.')
61
+ parser.add_argument(
62
+ '--box_prompt',
63
+ type=int,
64
+ nargs='+',
65
+ default=None,
66
+ help='box prompt format as xyxy.')
67
+ parser.add_argument(
68
+ '--output_path',
69
+ type=str,
70
+ default='./output/',
71
+ help='The directory for saving the results')
72
+ return parser.parse_args()
73
+
74
+
75
+ def show_mask(mask, ax, random_color=False):
76
+ if random_color:
77
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
78
+ else:
79
+ color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
80
+ h, w = mask.shape[-2:]
81
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
82
+ ax.imshow(mask_image)
83
+
84
+
85
+ def main(args):
86
+ if paddle.is_compiled_with_cuda():
87
+ paddle.set_device("gpu")
88
+ else:
89
+ paddle.set_device("cpu")
90
+ input_path = args.input_path
91
+ output_path = args.output_path
92
+ point, box = args.point_prompt, args.box_prompt
93
+ if point is not None:
94
+ point = np.array([point])
95
+ input_label = np.array([1])
96
+ else:
97
+ input_label = None
98
+ if box is not None:
99
+ box = np.array([[box[0], box[1]], [box[2], box[3]]])
100
+
101
+ image = cv2.imread(input_path)
102
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
103
+ model = sam_model_registry[args.model_type](
104
+ checkpoint=model_link[args.model_type])
105
+ predictor = SamPredictor(model)
106
+ predictor.set_image(image)
107
+
108
+ masks, _, _ = predictor.predict(
109
+ point_coords=point,
110
+ point_labels=input_label,
111
+ box=box,
112
+ multimask_output=True, )
113
+
114
+ plt.figure(figsize=(10, 10))
115
+ plt.imshow(image)
116
+ show_mask(masks[0], plt.gca())
117
+ plt.axis('off')
118
+ basename = os.path.basename(input_path)
119
+ if not os.path.exists(output_path):
120
+ os.makedirs(output_path)
121
+ path_output = os.path.join(output_path, basename)
122
+ plt.savefig(path_output)
123
+ print('The output has been saved to {}'.format(path_output))
124
+
125
+
126
+ if __name__ == "__main__":
127
+ args = get_args()
128
+ main(args)
text_to_sam_clip.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import cv2
17
+ import time
18
+ import sys
19
+ import argparse
20
+ sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))
21
+
22
+ import paddle
23
+ import paddle.nn.functional as F
24
+ import numpy as np
25
+ from PIL import Image, ImageDraw
26
+ import matplotlib.pyplot as plt
27
+
28
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
29
+ from segment_anything.modeling.clip_paddle import build_clip_model, _transform
30
+ from segment_anything.utils.sample_tokenizer import tokenize
31
+ from paddleseg.utils.visualize import get_pseudo_color_map, get_color_map_list
32
+
33
+ ID_PHOTO_IMAGE_DEMO = "./examples/cityscapes_demo.png"
34
+ CACHE_DIR = ".temp"
35
+ model_link = {
36
+ 'vit_h':
37
+ "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams",
38
+ 'vit_l':
39
+ "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams",
40
+ 'vit_b':
41
+ "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams",
42
+ 'vit_t':
43
+ "https://paddleseg.bj.bcebos.com/dygraph/paddlesegAnything/vit_t/model.pdparam",
44
+ 'clip_b_32':
45
+ "https://bj.bcebos.com/paddleseg/dygraph/clip/vit_b_32_pretrain/clip_vit_b_32.pdparams"
46
+ }
47
+
48
+ parser = argparse.ArgumentParser(description=(
49
+ "Runs automatic mask generation on an input image or directory of images, "
50
+ "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
51
+ "as well as pycocotools if saving in RLE format."))
52
+
53
+ parser.add_argument(
54
+ "--model-type",
55
+ type=str,
56
+ default="vit_h",
57
+ required=True,
58
+ help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b', 'vit_t']", )
59
+
60
+
61
+ def download(img):
62
+ if not os.path.exists(CACHE_DIR):
63
+ os.makedirs(CACHE_DIR)
64
+ while True:
65
+ name = str(int(time.time()))
66
+ tmp_name = os.path.join(CACHE_DIR, name + '.jpg')
67
+ if not os.path.exists(tmp_name):
68
+ break
69
+ else:
70
+ time.sleep(1)
71
+ img.save(tmp_name, 'png')
72
+ return tmp_name
73
+
74
+
75
+ def segment_image(image, segment_mask):
76
+ image_array = np.array(image)
77
+ gray_image = Image.new("RGB", image.size, (128, 128, 128))
78
+ segmented_image_array = np.zeros_like(image_array)
79
+ segmented_image_array[segment_mask] = image_array[segment_mask]
80
+ segmented_image = Image.fromarray(segmented_image_array)
81
+ transparency = np.zeros_like(segment_mask, dtype=np.uint8)
82
+ transparency[segment_mask] = 255
83
+ transparency_image = Image.fromarray(transparency, mode='L')
84
+ gray_image.paste(segmented_image, mask=transparency_image)
85
+ return gray_image
86
+
87
+
88
+ def image_text_match(cropped_objects, text_query):
89
+ transformed_images = [transform(image) for image in cropped_objects]
90
+ tokenized_text = tokenize([text_query])
91
+ batch_images = paddle.stack(transformed_images)
92
+ image_features = model.encode_image(batch_images)
93
+ print("encode_image done!")
94
+ text_features = model.encode_text(tokenized_text)
95
+ print("encode_text done!")
96
+ image_features /= image_features.norm(axis=-1, keepdim=True)
97
+ text_features /= text_features.norm(axis=-1, keepdim=True)
98
+ if len(text_features.shape) == 3:
99
+ text_features = text_features.squeeze(0)
100
+ probs = 100. * image_features @text_features.T
101
+ return F.softmax(probs[:, 0], axis=0)
102
+
103
+
104
+ def masks2pseudomap(masks):
105
+ result = np.ones(masks[0]["segmentation"].shape, dtype=np.uint8) * 255
106
+ for i, mask_data in enumerate(masks):
107
+ result[mask_data["segmentation"] == 1] = i + 1
108
+ pred_result = result
109
+ result = get_pseudo_color_map(result)
110
+ return pred_result, result
111
+
112
+
113
+ def visualize(image, result, color_map, weight=0.6):
114
+ """
115
+ Convert predict result to color image, and save added image.
116
+
117
+ Args:
118
+ image (str): The path of origin image.
119
+ result (np.ndarray): The predict result of image.
120
+ color_map (list): The color used to save the prediction results.
121
+ save_dir (str): The directory for saving visual image. Default: None.
122
+ weight (float): The image weight of visual image, and the result weight is (1 - weight). Default: 0.6
123
+
124
+ Returns:
125
+ vis_result (np.ndarray): If `save_dir` is None, return the visualized result.
126
+ """
127
+
128
+ color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
129
+ color_map = np.array(color_map).astype("uint8")
130
+ # Use OpenCV LUT for color mapping
131
+ c1 = cv2.LUT(result, color_map[:, 0])
132
+ c2 = cv2.LUT(result, color_map[:, 1])
133
+ c3 = cv2.LUT(result, color_map[:, 2])
134
+ pseudo_img = np.dstack((c3, c2, c1))
135
+
136
+ vis_result = cv2.addWeighted(image, weight, pseudo_img, 1 - weight, 0)
137
+ return vis_result
138
+
139
+
140
+ def get_id_photo_output(image, text):
141
+ """
142
+ Get the special size and background photo.
143
+
144
+ Args:
145
+ img(numpy:ndarray): The image array.
146
+ size(str): The size user specified.
147
+ bg(str): The background color user specified.
148
+ download_size(str): The size for image saving.
149
+
150
+ """
151
+ image_ori = image.copy()
152
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
153
+ masks = mask_generator.generate(image)
154
+ pred_result, pseudo_map = masks2pseudomap(masks) # PIL Image
155
+ added_pseudo_map = visualize(
156
+ image, pred_result, color_map=get_color_map_list(256))
157
+ cropped_objects = []
158
+ image_pil = Image.fromarray(image)
159
+ for mask in masks:
160
+ bbox = [
161
+ mask["bbox"][0], mask["bbox"][1], mask["bbox"][0] + mask["bbox"][2],
162
+ mask["bbox"][1] + mask["bbox"][3]
163
+ ]
164
+ cropped_objects.append(
165
+ segment_image(image_pil, mask["segmentation"]).crop(bbox))
166
+
167
+ scores = image_text_match(cropped_objects, str(text))
168
+ text_matching_masks = []
169
+ for idx, score in enumerate(scores):
170
+ if score < 0.05:
171
+ continue
172
+ text_matching_mask = Image.fromarray(
173
+ masks[idx]["segmentation"].astype('uint8') * 255)
174
+ text_matching_masks.append(text_matching_mask)
175
+
176
+ image_pil_ori = Image.fromarray(image_ori)
177
+ alpha_image = Image.new('RGBA', image_pil_ori.size, (0, 0, 0, 0))
178
+ alpha_color = (255, 0, 0, 180)
179
+
180
+ draw = ImageDraw.Draw(alpha_image)
181
+ for text_matching_mask in text_matching_masks:
182
+ draw.bitmap((0, 0), text_matching_mask, fill=alpha_color)
183
+
184
+ result_image = Image.alpha_composite(
185
+ image_pil_ori.convert('RGBA'), alpha_image)
186
+ res_download = download(result_image)
187
+ return result_image, added_pseudo_map, res_download
188
+
189
+
190
+ def gradio_display():
191
+ import gradio as gr
192
+ examples_sam = [["./examples/cityscapes_demo.png", "a photo of car"],
193
+ ["examples/dog.jpg", "dog"],
194
+ ["examples/zixingche.jpeg", "kid"]]
195
+
196
+ demo_mask_sam = gr.Interface(
197
+ fn=get_id_photo_output,
198
+ inputs=[
199
+ gr.Image(label="Input image", height=400),
200
+ gr.Textbox(label="Input text prompt", value="a car"),
201
+ ],
202
+ outputs=[
203
+ gr.Image(label="Output based on text", height=300),
204
+ gr.Image(label="Output mask", height=300)
205
+ ],
206
+ examples=examples_sam,
207
+ description="<p> \
208
+ <strong>SAM+CLIP: Text prompt for segmentation. </strong> <br>\
209
+ Choose an example below; Or, upload by yourself: <br>\
210
+ 1. Upload images to be tested to 'input image'. 2. Input a text prompt to 'input text prompt' and click 'submit'</strong>. <br>\
211
+ </p>",
212
+ cache_examples=False,
213
+ flagging_mode="never"
214
+ )
215
+
216
+ demo = gr.TabbedInterface(
217
+ [demo_mask_sam],
218
+ ['SAM+CLIP(Text to Segment)'],
219
+ title=" 🔥 Text to Segment Anything with PaddleSeg 🔥"
220
+ )
221
+
222
+ demo.launch(
223
+ server_name="0.0.0.0",
224
+ server_port=8078,
225
+ share=True
226
+ )
227
+
228
+ args = parser.parse_args()
229
+ print("Loading model...")
230
+
231
+ if paddle.is_compiled_with_cuda():
232
+ paddle.set_device("gpu")
233
+ else:
234
+ paddle.set_device("cpu")
235
+
236
+ sam = sam_model_registry[args.model_type](
237
+ checkpoint=model_link[args.model_type])
238
+ mask_generator = SamAutomaticMaskGenerator(sam)
239
+
240
+ model, transform = build_clip_model(model_link["clip_b_32"])
241
+ gradio_display()