jackyccl committed on
Commit
6930c1c
·
1 Parent(s): 1c7b1a7

Add Inpainting Function

Browse files
Files changed (1) hide show
  1. app.py +175 -44
app.py CHANGED
@@ -26,12 +26,18 @@ import groundingdino.datasets.transforms as T
26
  # segment anything
27
  from segment_anything import build_sam, SamPredictor
28
 
 
 
 
29
  from huggingface_hub import hf_hub_download
30
 
 
 
 
31
  if not os.path.exists('./sam_vit_h_4b8939.pth'):
32
  logger.info(f"get sam_vit_h_4b8939.pth...")
33
  result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
34
- print(f'wget sam_vit_h_4b8939.pth result = {result}')
35
 
36
  # Use this command for evaluate the GLIP-T model
37
  config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
@@ -112,7 +118,7 @@ def plot_boxes_to_image(image_pil, tgt):
112
  # bbox = draw.textbbox((x0, y0), str(label))
113
  draw.rectangle(bbox, fill=color)
114
  font = os.path.join(cv2.__path__[0],'qt','fonts','DejaVuSans.ttf')
115
- font_size = 36
116
  new_font = ImageFont.truetype(font, font_size)
117
 
118
  draw.text((x0+2, y0+2), str(label), font=new_font, fill="white")
@@ -133,8 +139,8 @@ def show_mask(mask, ax, random_color=False):
133
  def show_box(box, ax, label):
134
  x0, y0 = box[0], box[1]
135
  w, h = box[2] - box[0], box[3] - box[1]
136
- ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0,0,0,0), lw=2))
137
- ax.text(x0, y0+20, label, fontdict={'fontsize': 10}, color="white")
138
 
139
  def get_grounding_box(image_tensor, grounding_caption, box_threshold, text_threshold):
140
  # run grounding
@@ -148,7 +154,32 @@ def get_grounding_box(image_tensor, grounding_caption, box_threshold, text_thres
148
  # image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
149
  return boxes, labels
150
 
151
- def grounding_sam(input_image, text_prompt, task_type, box_threshold, text_threshold, iou_threshold):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  text_prompt = text_prompt.strip()
153
 
154
  # user guidance messages
@@ -160,31 +191,45 @@ def grounding_sam(input_image, text_prompt, task_type, box_threshold, text_thres
160
  return [], gr.Gallery.update(label='Please upload a image~~')
161
 
162
  file_temp = int(time.time())
 
 
 
 
 
 
163
  image_pil, image_tensor = load_image_and_transform(input_image['image'])
164
 
165
- # get dino bounding boxes
166
- boxes, phrases = get_grounding_box(image_tensor, text_prompt, box_threshold, text_threshold)
167
- if boxes.size(0) == 0:
168
- logger.info(f'run_grounded_sam_[]_{task_type}_[{text_prompt}]_1_[No objects detected, please try others.]_')
169
- return [], gr.Gallery.update(label='No objects detected, please try others!')
170
-
171
- size = image_pil.size
172
- pred_dict = {
173
- "boxes": boxes,
174
- "size": [size[1], size[0]], # H,W
175
- "labels": phrases,
176
- }
177
-
178
- # store and save dino output
179
- output_images = []
180
- image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
181
- image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
182
- image_with_box.save(image_path)
183
- detection_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
184
- os.remove(image_path)
185
- output_images.append(detection_image_result)
186
 
187
- if task_type == 'segment':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  image = np.array(input_image['image'])
189
  sam_predictor.set_image(image)
190
 
@@ -223,14 +268,83 @@ def grounding_sam(input_image, text_prompt, task_type, box_threshold, text_thres
223
  segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
224
  os.remove(image_path)
225
  output_images.append(segment_image_result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
- return output_images, gr.Gallery.update(label='result images')
228
 
 
 
 
229
  groundingDino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, groundingdino_device)
230
  sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
231
 
 
 
 
 
 
 
 
 
 
 
232
  if __name__ == "__main__":
233
-
 
 
 
234
  parser = argparse.ArgumentParser("Grounding SAM demo", add_help=True)
235
  parser.add_argument("--debug", action="store_true", help="using debug mode")
236
  parser.add_argument("--share", action="store_true", help="share the app")
@@ -240,15 +354,22 @@ if __name__ == "__main__":
240
 
241
  block = gr.Blocks().queue()
242
  with block:
243
- gr.Markdown("# GroundingDino and SAM")
244
  with gr.Row():
245
  with gr.Column():
246
- input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
247
- task_type = gr.Radio(["segment"], value="segment",
248
- label='Task type',interactive=True, visible=True)
249
- text_prompt = gr.Textbox(label="Detection Prompt, seperating each name with dot '.', i.e.: cat.dog.chair ]", \
250
- placeholder="Cannot be empty")
251
-
 
 
 
 
 
 
 
252
  run_button = gr.Button(label="Run")
253
  with gr.Accordion("Advanced options", open=False):
254
  box_threshold = gr.Slider(
@@ -259,18 +380,28 @@ if __name__ == "__main__":
259
  )
260
  iou_threshold = gr.Slider(
261
  label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
262
- )
 
 
 
 
 
 
263
 
264
  with gr.Column():
265
- gallery = gr.Gallery(
266
- label="result images", show_label=True, elem_id="gallery"
267
- ).style(grid=[2], full_width=True, full_height=True)
268
- # gallery = gr.Gallery(label="Generated images", show_label=False).style(
269
- # grid=[1], height="auto", container=True, full_width=True, full_height=True)
270
 
271
  DESCRIPTION = '### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything) and kudos to thier excellent works. Welcome everyone to try this out and learn together!'
272
  gr.Markdown(DESCRIPTION)
273
- run_button.click(fn=grounding_sam, inputs=[
274
- input_image, text_prompt, task_type, box_threshold, text_threshold, iou_threshold], outputs=[gallery, gallery])
 
 
 
 
275
 
276
  block.launch(debug=args.debug, share=args.share, show_api=False, show_error=True)
 
26
  # segment anything
27
  from segment_anything import build_sam, SamPredictor
28
 
29
+ #stable diffusion
30
+ from diffusers import StableDiffusionInpaintPipeline
31
+
32
  from huggingface_hub import hf_hub_download
33
 
34
+ if not os.path.exists('./demo2.jpg'):
35
+ os.system("wget https://github.com/IDEA-Research/Grounded-Segment-Anything/raw/main/assets/demo2.jpg")
36
+
37
  if not os.path.exists('./sam_vit_h_4b8939.pth'):
38
  logger.info(f"get sam_vit_h_4b8939.pth...")
39
  result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
40
+ print(f'wget sam_vit_h_4b8939.pth result = {result}')
41
 
42
  # Use this command for evaluate the GLIP-T model
43
  config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
 
118
  # bbox = draw.textbbox((x0, y0), str(label))
119
  draw.rectangle(bbox, fill=color)
120
  font = os.path.join(cv2.__path__[0],'qt','fonts','DejaVuSans.ttf')
121
+ font_size = 20
122
  new_font = ImageFont.truetype(font, font_size)
123
 
124
  draw.text((x0+2, y0+2), str(label), font=new_font, fill="white")
 
139
  def show_box(box, ax, label):
140
  x0, y0 = box[0], box[1]
141
  w, h = box[2] - box[0], box[3] - box[1]
142
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0,0,0,0), lw=1))
143
+ ax.text(x0, y0+20, label, fontdict={'fontsize': 6}, color="white")
144
 
145
  def get_grounding_box(image_tensor, grounding_caption, box_threshold, text_threshold):
146
  # run grounding
 
154
  # image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
155
  return boxes, labels
156
 
157
+ def mask_extend(img, box, extend_pixels=10, useRectangle=True):
158
+ box[0] = int(box[0])
159
+ box[1] = int(box[1])
160
+ box[2] = int(box[2])
161
+ box[3] = int(box[3])
162
+ region = img.crop(tuple(box)) # crop based on bb box
163
+ new_width = box[2] - box[0] + 2*extend_pixels
164
+ new_height = box[3] - box[1] + 2*extend_pixels
165
+
166
+ region_BILINEAR = region.resize((int(new_width), int(new_height))) # resize the cropped region based on "extend_pixels"
167
+ if useRectangle:
168
+ region_draw = ImageDraw.Draw(region_BILINEAR)
169
+ region_draw.rectangle((0, 0, new_width, new_height), fill=(255, 255, 255)) # draw white rectangle
170
+ img.paste(region_BILINEAR, (int(box[0]-extend_pixels), int(box[1]-extend_pixels))) #pastes the resized region back into the original image at the same location as the original bounding box but with an additional padding of extend_pixels pixels on all sides
171
+ return img
172
+
173
+ def mix_masks(imgs):
174
+ re_img = 1 - np.asarray(imgs[0].convert("1"))
175
+ for i in range(len(imgs)-1):
176
+ re_img = np.multiply(re_img, 1 - np.asarray(imgs[i+1].convert("1")))
177
+ re_img = 1 - re_img
178
+ return Image.fromarray(np.uint8(255*re_img))
179
+
180
+ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
181
+ iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
182
+
183
  text_prompt = text_prompt.strip()
184
 
185
  # user guidance messages
 
191
  return [], gr.Gallery.update(label='Please upload a image~~')
192
 
193
  file_temp = int(time.time())
194
+
195
+ # load mask
196
+ input_mask_pil = input_image['mask']
197
+ input_mask = np.array(input_mask_pil.convert("L"))
198
+
199
+ # load image
200
  image_pil, image_tensor = load_image_and_transform(input_image['image'])
201
 
202
+ # RUN GROUNDINGDINO: we skip DINO if we draw mask on the image
203
+ if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
204
+ pass
205
+ else:
206
+ boxes, phrases = get_grounding_box(image_tensor, text_prompt, box_threshold, text_threshold)
207
+ if boxes.size(0) == 0:
208
+ logger.info(f'run_grounded_sam_[]_{task_type}_[{text_prompt}]_1_[No objects detected, please try others.]_')
209
+ return [], gr.Gallery.update(label='No objects detected, please try others!')
210
+ boxes_filt_ori = copy.deepcopy(boxes)
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
+ size = image_pil.size
213
+
214
+ pred_dict = {
215
+ "boxes": boxes,
216
+ "size": [size[1], size[0]], # H,W
217
+ "labels": phrases,
218
+ }
219
+
220
+ # store and save DINO output
221
+ output_images = []
222
+ image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
223
+ image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
224
+ image_with_box.save(image_path)
225
+ detection_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
226
+ os.remove(image_path)
227
+ output_images.append(detection_image_result)
228
+
229
+ # if mask is detected from DINO
230
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
231
+ if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove')
232
+ and mask_source_radio == mask_source_segment):
233
  image = np.array(input_image['image'])
234
  sam_predictor.set_image(image)
235
 
 
268
  segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
269
  os.remove(image_path)
270
  output_images.append(segment_image_result)
271
+
272
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
273
+ if task_type == 'segment':
274
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_Final_')
275
+ return output_images, gr.Gallery.update(label='result images')
276
+
277
+ elif task_type == 'inpainting' or task_type == 'remove':
278
+ # if no inpaint prompt is entered, we treat it as remove
279
+ if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
280
+ task_type = 'remove'
281
+
282
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
283
+ if mask_source_radio == mask_source_draw:
284
+ mask_pil = input_mask_pil
285
+ mask = input_mask
286
+ else:
287
+ masks_ori = copy.deepcopy(masks)
288
+ # inpainting pipeline
289
+ if inpaint_mode == 'merge':
290
+ masks = torch.sum(masks, dim=0).unsqueeze(0)
291
+ masks = torch.where(masks > 0, True, False)
292
+
293
+ # simply choose the first mask, which will be refine in the future release
294
+ mask = masks[0][0].cpu().numpy()
295
+ mask_pil = Image.fromarray(mask)
296
+ output_images.append(mask_pil.convert("RGB"))
297
+
298
+ if task_type == 'inpainting':
299
+ # inpainting pipeline
300
+ image_source_for_inpaint = image_pil.resize((512, 512))
301
+ image_mask_for_inpaint = mask_pil.resize((512, 512))
302
+ image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
303
+
304
+ image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
305
+ output_images.append(image_inpainting)
306
+ return output_images, gr.Gallery.update(label='result images')
307
+ else:
308
+ logger.info(f"task_type:{task_type} error!")
309
+ logger.info(f'run_anything_task_[{file_temp}]_Final_Inpainting_')
310
+ return output_images, gr.Gallery.update(label='result images')
311
+
312
+
313
+ def change_radio_display(task_type, mask_source_radio):
314
+ text_prompt_visible = True
315
+ inpaint_prompt_visible = False
316
+ mask_source_radio_visible = False
317
+
318
+ if task_type == "inpainting":
319
+ inpaint_prompt_visible = True
320
+ if task_type == "inpainting" or task_type == "remove":
321
+ mask_source_radio_visible = True
322
+ if mask_source_radio == mask_source_draw:
323
+ text_prompt_visible = False
324
 
325
+ return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible)
326
 
327
+
328
+
329
+ # model initialization
330
  groundingDino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, groundingdino_device)
331
  sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
332
 
333
+ # initialize stable-diffusion-inpainting
334
+ logger.info(f"initialize stable-diffusion-inpainting...")
335
+ sd_pipe = None
336
+ if os.environ.get('IS_MY_DEBUG') is None:
337
+ sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
338
+ "runwayml/stable-diffusion-inpainting",
339
+ torch_dtype=torch.float16
340
+ )
341
+ sd_pipe = sd_pipe.to(device)
342
+
343
  if __name__ == "__main__":
344
+
345
+ mask_source_draw = "Draw mask on image."
346
+ mask_source_segment = "Segment based on prompt and inpaint."
347
+
348
  parser = argparse.ArgumentParser("Grounding SAM demo", add_help=True)
349
  parser.add_argument("--debug", action="store_true", help="using debug mode")
350
  parser.add_argument("--share", action="store_true", help="share the app")
 
354
 
355
  block = gr.Blocks().queue()
356
  with block:
357
+ gr.Markdown("# GroundingDino SAM and Stable Diffusion")
358
  with gr.Row():
359
  with gr.Column():
360
+ input_image = gr.Image(
361
+ source="upload", elem_id="image_upload", type="pil", tool="sketch", value="demo2.jpg", label="Upload")
362
+ task_type = gr.Radio(["segment", "inpainting", "remove"], value="segment",
363
+ label='Task type', visible=True)
364
+
365
+ mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
366
+ value=mask_source_segment, label="Mask from",
367
+ visible=False)
368
+
369
+ text_prompt = gr.Textbox(label="Detection Prompt, seperating each name with dot '.', i.e.: bear.cat.dog.chair ]", \
370
+ value='bear', placeholder="Cannot be empty")
371
+ inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
372
+
373
  run_button = gr.Button(label="Run")
374
  with gr.Accordion("Advanced options", open=False):
375
  box_threshold = gr.Slider(
 
380
  )
381
  iou_threshold = gr.Slider(
382
  label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
383
+ )
384
+ inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode")
385
+ with gr.Row():
386
+ with gr.Column(scale=1):
387
+ remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
388
+ with gr.Column(scale=1):
389
+ remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
390
 
391
  with gr.Column():
392
+ gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
393
+ ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
394
+
395
+ task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio])
396
+ mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio])
397
 
398
  DESCRIPTION = '### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything) and kudos to thier excellent works. Welcome everyone to try this out and learn together!'
399
  gr.Markdown(DESCRIPTION)
400
+
401
+ run_button.click(fn=run_anything_task, inputs=[
402
+ input_image, text_prompt, task_type, inpaint_prompt,
403
+ box_threshold,text_threshold, iou_threshold, inpaint_mode,
404
+ mask_source_radio, remove_mode, remove_mask_extend],
405
+ outputs=[gallery, gallery], show_progress=True, queue=True)
406
 
407
  block.launch(debug=args.debug, share=args.share, show_api=False, show_error=True)