Files changed (1) hide show
  1. app.py +185 -315
app.py CHANGED
@@ -1,337 +1,207 @@
1
- import os, sys
2
- import random
3
- import warnings
4
-
5
- os.system("python -m pip install -e segment_anything")
6
- os.system("python -m pip install -e GroundingDINO")
7
- os.system("pip install --upgrade diffusers[torch]")
8
- os.system("pip install opencv-python pycocotools matplotlib onnxruntime onnx ipykernel")
9
- os.system("wget https://github.com/IDEA-Research/Grounded-Segment-Anything/raw/main/assets/demo1.jpg")
10
- os.system("wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth")
11
- os.system("wget https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth")
12
- sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
13
- sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
14
- warnings.filterwarnings("ignore")
15
-
16
  import gradio as gr
17
- import argparse
18
-
19
  import numpy as np
20
- import torch
21
- import torchvision
22
- from PIL import Image, ImageDraw, ImageFont
23
 
24
- # Grounding DINO
25
- import GroundingDINO.groundingdino.datasets.transforms as T
26
- from GroundingDINO.groundingdino.models import build_model
27
- from GroundingDINO.groundingdino.util.slconfig import SLConfig
28
- from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
 
29
 
30
- # segment anything
31
- from segment_anything import build_sam, SamPredictor
32
- import numpy as np
33
 
34
- # diffusers
35
- import torch
36
- from diffusers import StableDiffusionInpaintPipeline
37
 
38
- # BLIP
39
- from transformers import BlipProcessor, BlipForConditionalGeneration
 
40
 
 
 
41
 
42
- def generate_caption(processor, blip_model, raw_image):
43
- # unconditional image captioning
44
- inputs = processor(raw_image, return_tensors="pt").to(
45
- "cuda", torch.float16)
46
- out = blip_model.generate(**inputs)
47
- caption = processor.decode(out[0], skip_special_tokens=True)
48
- return caption
49
 
 
50
 
51
- def transform_image(image_pil):
52
 
53
- transform = T.Compose(
54
- [
55
- T.RandomResize([800], max_size=1333),
56
- T.ToTensor(),
57
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
58
- ]
59
- )
60
- image, _ = transform(image_pil, None) # 3, h, w
61
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
 
63
 
64
- def load_model(model_config_path, model_checkpoint_path, device):
65
- args = SLConfig.fromfile(model_config_path)
66
- args.device = device
67
- model = build_model(args)
68
- checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
69
- load_res = model.load_state_dict(
70
- clean_state_dict(checkpoint["model"]), strict=False)
71
- print(load_res)
72
- _ = model.eval()
73
- return model
74
 
 
 
 
 
 
 
75
 
76
- def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
77
- caption = caption.lower()
78
- caption = caption.strip()
79
- if not caption.endswith("."):
80
- caption = caption + "."
 
 
 
81
 
82
  with torch.no_grad():
83
- outputs = model(image[None], captions=[caption])
84
- logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
85
- boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
86
- logits.shape[0]
87
-
88
- # filter output
89
- logits_filt = logits.clone()
90
- boxes_filt = boxes.clone()
91
- filt_mask = logits_filt.max(dim=1)[0] > box_threshold
92
- logits_filt = logits_filt[filt_mask] # num_filt, 256
93
- boxes_filt = boxes_filt[filt_mask] # num_filt, 4
94
- logits_filt.shape[0]
95
-
96
- # get phrase
97
- tokenlizer = model.tokenizer
98
- tokenized = tokenlizer(caption)
99
- # build pred
100
- pred_phrases = []
101
- scores = []
102
- for logit, box in zip(logits_filt, boxes_filt):
103
- pred_phrase = get_phrases_from_posmap(
104
- logit > text_threshold, tokenized, tokenlizer)
105
- if with_logits:
106
- pred_phrases.append(
107
- pred_phrase + f"({str(logit.max().item())[:4]})")
108
- else:
109
- pred_phrases.append(pred_phrase)
110
- scores.append(logit.max().item())
111
-
112
- return boxes_filt, torch.Tensor(scores), pred_phrases
113
-
114
-
115
- def draw_mask(mask, draw, random_color=False):
116
- if random_color:
117
- color = (random.randint(0, 255), random.randint(
118
- 0, 255), random.randint(0, 255), 153)
119
- else:
120
- color = (30, 144, 255, 153)
121
-
122
- nonzero_coords = np.transpose(np.nonzero(mask))
123
-
124
- for coord in nonzero_coords:
125
- draw.point(coord[::-1], fill=color)
126
-
127
-
128
- def draw_box(box, draw, label):
129
- # random color
130
- color = tuple(np.random.randint(0, 255, size=3).tolist())
131
-
132
- draw.rectangle(((box[0], box[1]), (box[2], box[3])),
133
- outline=color, width=2)
134
-
135
- if label:
136
- font = ImageFont.load_default()
137
- if hasattr(font, "getbbox"):
138
- bbox = draw.textbbox((box[0], box[1]), str(label), font)
139
- else:
140
- w, h = draw.textsize(str(label), font)
141
- bbox = (box[0], box[1], w + box[0], box[1] + h)
142
- draw.rectangle(bbox, fill=color)
143
- draw.text((box[0], box[1]), str(label), fill="white")
144
-
145
- draw.text((box[0], box[1]), label)
146
-
147
-
148
- config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
149
- ckpt_repo_id = "ShilongLiu/GroundingDINO"
150
- ckpt_filenmae = "groundingdino_swint_ogc.pth"
151
- sam_checkpoint = 'sam_vit_h_4b8939.pth'
152
- output_dir = "outputs"
153
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
154
-
155
-
156
- blip_processor = None
157
- blip_model = None
158
- groundingdino_model = None
159
- sam_predictor = None
160
- inpaint_pipeline = None
161
-
162
-
163
- def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode):
164
-
165
- global blip_processor, blip_model, groundingdino_model, sam_predictor, inpaint_pipeline
166
-
167
- # make dir
168
- os.makedirs(output_dir, exist_ok=True)
169
- # load image
170
- image_pil = input_image.convert("RGB")
171
- transformed_image = transform_image(image_pil)
172
-
173
- if groundingdino_model is None:
174
- groundingdino_model = load_model(
175
- config_file, ckpt_filenmae, device=device)
176
-
177
- if task_type == 'automatic':
178
- # generate caption and tags
179
- # use Tag2Text can generate better captions
180
- # https://huggingface.co/spaces/xinyu1205/Tag2Text
181
- # but there are some bugs...
182
- blip_processor = blip_processor or BlipProcessor.from_pretrained(
183
- "Salesforce/blip-image-captioning-large")
184
- blip_model = blip_model or BlipForConditionalGeneration.from_pretrained(
185
- "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
186
- text_prompt = generate_caption(blip_processor, blip_model, image_pil)
187
- print(f"Caption: {text_prompt}")
188
-
189
- # run grounding dino model
190
- boxes_filt, scores, pred_phrases = get_grounding_output(
191
- groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold
192
  )
193
 
194
- size = image_pil.size
195
-
196
- # process boxes
197
- H, W = size[1], size[0]
198
- for i in range(boxes_filt.size(0)):
199
- boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
200
- boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
201
- boxes_filt[i][2:] += boxes_filt[i][:2]
202
-
203
- boxes_filt = boxes_filt.cpu()
204
-
205
- # nms
206
- print(f"Before NMS: {boxes_filt.shape[0]} boxes")
207
- nms_idx = torchvision.ops.nms(
208
- boxes_filt, scores, iou_threshold).numpy().tolist()
209
- boxes_filt = boxes_filt[nms_idx]
210
- pred_phrases = [pred_phrases[idx] for idx in nms_idx]
211
- print(f"After NMS: {boxes_filt.shape[0]} boxes")
212
-
213
- if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
214
- if sam_predictor is None:
215
- # initialize SAM
216
- assert sam_checkpoint, 'sam_checkpoint is not found!'
217
- sam = build_sam(checkpoint=sam_checkpoint)
218
- sam.to(device=device)
219
- sam_predictor = SamPredictor(sam)
220
-
221
- image = np.array(image_pil)
222
- sam_predictor.set_image(image)
223
-
224
- if task_type == 'automatic':
225
- # use NMS to handle overlapped boxes
226
- print(f"Revise caption with number: {text_prompt}")
227
-
228
- transformed_boxes = sam_predictor.transform.apply_boxes_torch(
229
- boxes_filt, image.shape[:2]).to(device)
230
-
231
- masks, _, _ = sam_predictor.predict_torch(
232
- point_coords=None,
233
- point_labels=None,
234
- boxes=transformed_boxes,
235
- multimask_output=False,
236
- )
237
-
238
- # masks: [1, 1, 512, 512]
239
-
240
- if task_type == 'det':
241
- image_draw = ImageDraw.Draw(image_pil)
242
- for box, label in zip(boxes_filt, pred_phrases):
243
- draw_box(box, image_draw, label)
244
-
245
- return [image_pil]
246
- elif task_type == 'seg' or task_type == 'automatic':
247
-
248
- mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
249
-
250
- mask_draw = ImageDraw.Draw(mask_image)
251
- for mask in masks:
252
- draw_mask(mask[0].cpu().numpy(), mask_draw, random_color=True)
253
-
254
- image_draw = ImageDraw.Draw(image_pil)
255
-
256
- for box, label in zip(boxes_filt, pred_phrases):
257
- draw_box(box, image_draw, label)
258
-
259
- if task_type == 'automatic':
260
- image_draw.text((10, 10), text_prompt, fill='black')
261
-
262
- image_pil = image_pil.convert('RGBA')
263
- image_pil.alpha_composite(mask_image)
264
- return [image_pil, mask_image]
265
- elif task_type == 'inpainting':
266
- assert inpaint_prompt, 'inpaint_prompt is not found!'
267
- # inpainting pipeline
268
- if inpaint_mode == 'merge':
269
- masks = torch.sum(masks, dim=0).unsqueeze(0)
270
- masks = torch.where(masks > 0, True, False)
271
- # simply choose the first mask, which will be refine in the future release
272
- mask = masks[0][0].cpu().numpy()
273
- mask_pil = Image.fromarray(mask)
274
-
275
- if inpaint_pipeline is None:
276
- inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
277
- "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
278
- )
279
- inpaint_pipeline = inpaint_pipeline.to("cuda")
280
-
281
- image = inpaint_pipeline(prompt=inpaint_prompt, image=image_pil.resize(
282
- (512, 512)), mask_image=mask_pil.resize((512, 512))).images[0]
283
- image = image.resize(size)
284
-
285
- return [image, mask_pil]
286
- else:
287
- print("task_type:{} error!".format(task_type))
288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
  if __name__ == "__main__":
291
- parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
292
- parser.add_argument("--debug", action="store_true",
293
- help="using debug mode")
294
- parser.add_argument("--share", action="store_true", help="share the app")
295
- parser.add_argument('--no-gradio-queue', action="store_true",
296
- help='path to the SAM checkpoint')
297
- args = parser.parse_args()
298
-
299
- print(args)
300
-
301
- block = gr.Blocks()
302
- if not args.no_gradio_queue:
303
- block = block.queue()
304
-
305
-
306
- with block:
307
- with gr.Row():
308
- with gr.Column():
309
- input_image = gr.Image(
310
- source='upload', type="pil", value="demo1.jpg")
311
- task_type = gr.Dropdown(
312
- ["det", "seg", "inpainting", "automatic"], value="seg", label="task_type")
313
- text_prompt = gr.Textbox(label="Text Prompt", placeholder="bear . beach .")
314
- inpaint_prompt = gr.Textbox(label="Inpaint Prompt", placeholder="A dinosaur, detailed, 4K.")
315
- run_button = gr.Button(label="Run")
316
- with gr.Accordion("Advanced options", open=False):
317
- box_threshold = gr.Slider(
318
- label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
319
- )
320
- text_threshold = gr.Slider(
321
- label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
322
- )
323
- iou_threshold = gr.Slider(
324
- label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
325
- )
326
- inpaint_mode = gr.Dropdown(
327
- ["merge", "first"], value="merge", label="inpaint_mode")
328
-
329
- with gr.Column():
330
- gallery = gr.Gallery(
331
- label="Generated images", show_label=False, elem_id="gallery"
332
- ).style(preview=True, grid=2, object_fit="scale-down")
333
-
334
- run_button.click(fn=run_grounded_sam, inputs=[
335
- input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode], outputs=gallery)
336
-
337
- block.launch(debug=args.debug, share=args.share, show_error=True)
 
1
import os
import torch

import gradio as gr
import numpy as np
from PIL import Image

from transformers import (
    AutoProcessor,
    AutoModelForZeroShotObjectDetection,
    BlipProcessor,
    BlipForConditionalGeneration,
)

from segment_anything import sam_model_registry, SamPredictor

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --------------------------------------------------
# MODELS
# --------------------------------------------------

DINO_MODEL = "IDEA-Research/grounding-dino-base"
BLIP_MODEL = "Salesforce/blip-image-captioning-base"

SAM_TYPE = "vit_b"
SAM_CHECKPOINT = "sam_vit_b.pth"
SAM_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"

# Minimum detection confidence kept by detect().
BOX_THRESHOLD = 0.3

# --------------------------------------------------
# DOWNLOAD SAM
# --------------------------------------------------

if not os.path.exists(SAM_CHECKPOINT):
    import urllib.request

    print("Downloading SAM checkpoint...")
    try:
        urllib.request.urlretrieve(SAM_URL, SAM_CHECKPOINT)
    except BaseException:
        # A failed or interrupted download leaves a truncated file behind;
        # the existence check above would then skip re-downloading and SAM
        # would fail to load on every subsequent start. Remove the partial
        # file before propagating the error.
        if os.path.exists(SAM_CHECKPOINT):
            os.remove(SAM_CHECKPOINT)
        raise

# --------------------------------------------------
# LOAD MODELS (once, at import time)
# --------------------------------------------------

print("Loading GroundingDINO...")
processor = AutoProcessor.from_pretrained(DINO_MODEL)
dino = AutoModelForZeroShotObjectDetection.from_pretrained(DINO_MODEL).to(DEVICE)

print("Loading SAM...")
sam = sam_model_registry[SAM_TYPE](checkpoint=SAM_CHECKPOINT)
sam.to(device=DEVICE)
predictor = SamPredictor(sam)

print("Loading BLIP...")
blip_processor = BlipProcessor.from_pretrained(BLIP_MODEL)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL).to(DEVICE)
 
64
def generate_caption(image):
    """Return an unconditional BLIP caption for *image*.

    The image is encoded with the module-level ``blip_processor`` and the
    caption is generated by the module-level ``blip_model``.
    """
    encoded = blip_processor(image, return_tensors="pt").to(DEVICE)

    # No gradients needed for inference-only generation.
    with torch.no_grad():
        token_ids = blip_model.generate(**encoded)

    return blip_processor.decode(token_ids[0], skip_special_tokens=True)
74
 
75
+
76
+ # --------------------------------------------------
77
+ # DETECT OBJECTS
78
+ # --------------------------------------------------
79
+
80
def detect(image, prompt):
    """Run zero-shot object detection with GroundingDINO.

    Args:
        image: PIL image.
        prompt: free-text query. GroundingDINO expects lower-cased phrases
            terminated by a period (e.g. ``"a cat. a dog."``), so the prompt
            is normalized here — important for "automatic" mode, which feeds
            a raw BLIP caption in.

    Returns:
        Tensor of boxes (x1, y1, x2, y2) in pixel coordinates whose score
        exceeds ``BOX_THRESHOLD``.
    """
    # Normalize the query to the format GroundingDINO was trained on.
    text = prompt.lower().strip()
    if not text.endswith("."):
        text += "."

    inputs = processor(images=image, text=text, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        outputs = dino(**inputs)

    # input_ids are required so the processor can map predictions back to
    # the query tokens; image.size is (W, H), target_sizes wants (H, W).
    results = processor.post_process_grounded_object_detection(
        outputs,
        input_ids=inputs.input_ids,
        target_sizes=[image.size[::-1]],
    )[0]

    boxes = results["boxes"]
    scores = results["scores"]

    # Keep only confident detections.
    keep = scores > BOX_THRESHOLD

    return boxes[keep]
98
+
99
+
100
+ # --------------------------------------------------
101
+ # DRAW BOXES
102
+ # --------------------------------------------------
103
+
104
def draw_boxes(image, boxes):
    """Draw a 3-pixel red rectangle for each box onto a copy of *image*.

    Args:
        image: PIL image.
        boxes: iterable of (x1, y1, x2, y2) tensors in pixel coordinates.

    Returns:
        A new PIL image with the rectangles drawn; the input is not mutated.
    """
    canvas = np.array(image).copy()
    height, width = canvas.shape[:2]
    red = [255, 0, 0]

    for box in boxes:
        x1, y1, x2, y2 = box.cpu().numpy().astype(int)

        # Detector boxes can fall slightly outside the image; a negative
        # coordinate used as a slice bound would wrap to the opposite edge,
        # so clip everything into the valid range first.
        x1, x2 = np.clip([x1, x2], 0, width - 1)
        y1, y2 = np.clip([y1, y2], 0, height - 1)

        canvas[y1:y1 + 3, x1:x2] = red  # top edge
        canvas[y2:y2 + 3, x1:x2] = red  # bottom edge
        canvas[y1:y2, x1:x1 + 3] = red  # left edge
        canvas[y1:y2, x2:x2 + 3] = red  # right edge

    return Image.fromarray(canvas)
119
+
120
+
121
+ # --------------------------------------------------
122
+ # SEGMENT
123
+ # --------------------------------------------------
124
+
125
def segment(image, prompt):
    """Detect objects matching *prompt* and overlay their SAM masks in green.

    Returns the RGB-converted input unchanged when nothing is detected;
    otherwise returns a new PIL image with each mask blended 50/50 with
    solid green.
    """
    rgb = image.convert("RGB")
    pixels = np.array(rgb)

    detected = detect(rgb, prompt)
    if len(detected) == 0:
        # No boxes above threshold: nothing to segment.
        return rgb

    predictor.set_image(pixels)

    # SAM wants boxes on the model device, transformed to its input frame.
    sam_boxes = predictor.transform.apply_boxes_torch(
        detected.to(DEVICE), pixels.shape[:2]
    )

    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=sam_boxes,
        multimask_output=False,
    )

    overlay = pixels.copy()
    green = np.array([0, 255, 0])

    for mask in masks:
        region = mask[0].cpu().numpy() > 0
        # 50/50 blend of the original pixels with solid green.
        overlay[region] = (overlay[region] * 0.5 + green * 0.5).astype(np.uint8)

    return Image.fromarray(overlay)
161
+
162
+
163
+ # --------------------------------------------------
164
+ # PIPELINE
165
+ # --------------------------------------------------
166
+
167
def run_pipeline(image, prompt, mode):
    """Dispatch a Gradio request to the selected task.

    Modes:
        "seg"       -- GroundingDINO boxes refined into SAM mask overlays.
        "det"       -- GroundingDINO boxes drawn on the image.
        "automatic" -- BLIP caption used as the detection prompt.

    Raises:
        ValueError: if *mode* is not one of the supported tasks. (Previously
        an unknown mode fell through and implicitly returned None, which
        Gradio renders as an opaque failure.)
    """
    if mode == "seg":
        return segment(image, prompt)

    if mode == "det":
        boxes = detect(image, prompt)
        return draw_boxes(image, boxes)

    if mode == "automatic":
        # Let BLIP describe the image, then ground that description.
        caption = generate_caption(image)
        print("BLIP caption:", caption)
        return segment(image, caption)

    raise ValueError(f"Unknown mode: {mode!r}")
185
+
186
+
187
+ # --------------------------------------------------
188
+ # UI
189
+ # --------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
# Build the Gradio interface: image + prompt + mode in, one image out.
_image_input = gr.Image(type="pil")
_prompt_input = gr.Textbox(label="Prompt", value="person")
_mode_input = gr.Dropdown(["seg", "det", "automatic"], value="seg", label="Mode")

demo = gr.Interface(
    fn=run_pipeline,
    inputs=[_image_input, _prompt_input, _mode_input],
    outputs=gr.Image(),
    title="GroundingDINO + SAM + BLIP (CPU version)",
)

if __name__ == "__main__":
    demo.launch()