Spaces:
Runtime error
Runtime error
import os
import tempfile

import gradio as gr
import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.ops import box_convert

from groundingdino.util.inference import load_model, load_image, predict
from segment_anything import SamPredictor, sam_model_registry
# --- Configuration -----------------------------------------------------------
device = "cpu"

# Segment Anything (SAM) backbone variant and its checkpoint file.
model_type = "vit_b"
sam_checkpoint = "weights/sam_vit_b.pth"

# GroundingDINO model config and checkpoint file.
config = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
dino_checkpoint = "weights/groundingdino_swint_ogc.pth"

# Score cut-offs forwarded to groundingdino's predict() in detect_objects().
box_threshold = 0.35
text_threshold = 0.25

# --- Model loading (runs once, at import time) ---------------------------------
# NOTE(review): SAM is built without an explicit device; presumably it stays on
# the default (CPU) device, matching `device` above — confirm if a GPU is added.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
predictor = SamPredictor(sam)
model = load_model(config, dino_checkpoint, device)
def show_mask(mask, ax, random_color=False):
    """Draw *mask* on *ax* as a coloured RGBA overlay.

    Each pixel's overlay colour is the RGBA colour scaled by the mask value, so a
    boolean mask gives colour with alpha 0.6 where True and full transparency
    elsewhere. The colour is a fixed blue unless ``random_color`` is true, in
    which case the RGB part is random. Only the last two dimensions of ``mask``
    give height and width; any leading dimensions must have size 1 for the
    reshape to succeed.
    """
    if not random_color:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) mask against a (1, 1, 4) colour -> (H, W, 4) image.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_box(box, ax, label=None):
    """Outline an (x0, y0, x1, y1) box in red on *ax*, optionally captioned.

    The caption, if given, is drawn as white text on a red background anchored
    at the box's (x0, y0) corner.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='red',
        facecolor=(0, 0, 0, 0),  # fully transparent fill: outline only
        lw=2,
    )
    ax.add_patch(outline)
    if label is None:
        return
    ax.text(left, top, label, fontsize=12, color='white', backgroundcolor='red', ha='left', va='top')
def extract_object_with_transparent_background(image, masks):
    """Cut the object selected by ``masks[0]`` out of *image* as an RGBA array.

    Pixels inside the mask keep their RGB values and get alpha 255; all other
    pixels become (0, 0, 0, 0). *image* is expected to have 3 channels and the
    same height/width as the mask. Returns a ``uint8`` array of shape (H, W, 4).
    """
    mask = masks[0]
    # Same mask on every colour channel -> (H, W, 3).
    segment = image * np.stack([mask, mask, mask], axis=-1)
    rows, cols = segment.shape[:2]
    rgba = np.zeros((rows, cols, 4), dtype=np.uint8)
    rgba[..., :3] = segment
    rgba[..., 3] = mask * 255
    return rgba
def extract_remaining_image(image, masks):
    """Return *image* with the pixels covered by ``masks[0]`` set to zero.

    This is the complement of the object cut-out: everything outside the mask
    is kept. *image* is expected to have 3 channels matching the mask's size.
    """
    keep = np.logical_not(masks[0])
    return image * np.stack([keep] * 3, axis=-1)
def overlay_masks_boxes_on_image(image, masks, boxes, labels, show_masks, show_boxes):
    """Render *image* with optional mask/box overlays and return the pixels.

    Args:
        image: Array accepted by ``Axes.imshow``.
        masks: Masks drawn with ``show_mask`` when ``show_masks`` is true.
        boxes: (x0, y0, x1, y1) boxes drawn with ``show_box`` when ``show_boxes``
            is true, paired element-wise with ``labels``.
        labels: Captions for ``boxes`` (``zip`` stops at the shorter list).
        show_masks: Whether to draw ``masks``.
        show_boxes: Whether to draw ``boxes``.

    Returns:
        A copy of the rendered figure canvas as an RGBA array. The output size
        is the figure's size, not the input image's size.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(image)
        if show_masks:
            for mask in masks:
                show_mask(mask, ax, random_color=False)
        if show_boxes:
            for input_box, label in zip(boxes, labels):
                show_box(input_box, ax, label)
        ax.axis('off')
        # Configure this figure/axes directly rather than through pyplot's global
        # "current figure" state, which other threads could change concurrently.
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
        ax.margins(0, 0)
        fig.canvas.draw()
        # np.array copies the canvas buffer, so the result outlives the figure.
        output_image = np.array(fig.canvas.buffer_rgba())
    finally:
        # Always release the figure; pyplot otherwise keeps it alive, leaking
        # memory in a long-running server if drawing raises.
        plt.close(fig)
    return output_image
def _box_prompt_points(input_box):
    """Build a 4-point SAM prompt from an (x1, y1, x2, y2) pixel box.

    Returns ``(center_point, avg_size, points)``: the box centre as a float
    array, the mean of the box's width and height, and a (4, 2) array of points
    on a cross around the centre, each offset by 10% of ``avg_size``.
    """
    x1, y1, x2, y2 = input_box
    avg_size = ((x2 - x1) + (y2 - y1)) / 2
    d = avg_size * 0.1
    center_point = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
    cx, cy = center_point
    points = np.array([
        [cx, cy - d],  # above centre
        [cx, cy + d],  # below centre
        [cx - d, cy],  # left of centre
        [cx + d, cy],  # right of centre
    ])
    return center_point, avg_size, points


def _segment_from_points(input_point):
    """Segment with the module-level SAM ``predictor`` using a two-pass refinement.

    The first pass requests several candidate masks. The low-resolution logits of
    the highest-scoring candidate are fed back as ``mask_input`` for a
    single-mask second pass. The predictor's image must already be set.

    Returns:
        ``(masks, sam_score)``: the second-pass masks and the best first-pass
        score as a Python float.
    """
    # Every point gets label 1 (a foreground point in SAM's point-label convention).
    input_label = np.array([1] * len(input_point))
    _, scores, low_res_logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    best = np.argmax(scores)
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        mask_input=low_res_logits[best][None, :, :],
        multimask_output=False,
    )
    return masks, np.max(scores).item()


def detect_objects(image, prompt, show_masks=True, show_boxes=True, crop_options="No crop"):
    """Detect *prompt* objects with GroundingDINO, segment them with SAM, and crop each one.

    Args:
        image: Path to the input image (the Gradio input uses ``type='filepath'``).
        prompt: Text caption passed to GroundingDINO, e.g. ``"fish"``.
        show_masks, show_boxes, crop_options: Currently unused; kept so the
            call signature stays stable.

    Returns:
        ``[res_json, output_image_paths]``. ``res_json`` holds the prompt and one
        entry per detection (label, DINO score, SAM score, pixel box, centre,
        average size). ``output_image_paths`` lists the JPEG crops written for
        detections whose box has a non-zero area.

    Raises:
        gr.Error: If no image is supplied or the prompt is blank.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter an object to detect.")

    image_source, image = load_image(image)
    predictor.set_image(image_source)
    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=device,
    )

    # Normalised (cx, cy, w, h) -> integer pixel (x1, y1, x2, y2).
    h, w, _ = image_source.shape
    boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy") * torch.Tensor([w, h, w, h])
    boxes = np.round(boxes.numpy()).astype(int)
    # Clip to the image. A box reaching past the border would otherwise give a
    # negative coordinate, which wraps around when used as a slice index below.
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, w)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, h)

    # Per-call directory so concurrent requests cannot overwrite each other's crops.
    output_dir = tempfile.mkdtemp(prefix="detect_objects_")
    res_json = {"prompt": prompt, "objects": []}
    output_image_paths = []
    for i, (input_box, phrase, dino_score) in enumerate(zip(boxes, phrases, logits.tolist())):
        x1, y1, x2, y2 = input_box
        center_point, avg_size, input_point = _box_prompt_points(input_box)
        masks, sam_score = _segment_from_points(input_point)

        # Only the RGB channels are used. Outside the mask they are zero, so the
        # object appears on a black background.
        rgba_segment = extract_object_with_transparent_background(image_source, masks)
        cropped_image = rgba_segment[y1:y2, x1:x2, :3]
        if cropped_image.size:  # a zero-area box has nothing to render
            output_image = overlay_masks_boxes_on_image(cropped_image, [], [], [], False, False)
            output_image_path = os.path.join(output_dir, f"output_image_{i}.jpeg")
            plt.imsave(output_image_path, output_image)
            output_image_paths.append(output_image_path)

        res_json["objects"].append({
            "label": phrase,
            "dino_score": dino_score,
            "sam_score": sam_score,
            "box": input_box.tolist(),
            "center": center_point.tolist(),
            "avg_size": float(avg_size),
        })
    return [res_json, output_image_paths]
# --- Gradio UI ---------------------------------------------------------------
# Inputs map positionally onto detect_objects(image, prompt); its remaining
# parameters keep their defaults.
image_input = gr.Image(type='filepath', label="Upload Image")
prompt_input = gr.Textbox(
    label="Object to Detect",
    placeholder="Enter any text, comma separated if multiple objects needed",
    show_label=True,
    lines=1,
)
# (image path, prompt) pairs shown as clickable examples.
example_rows = [
    ["images/fish.jpg", "fish"],
    ["images/birds.png", "bird"],
    ["images/bear.png", "bear"],
    ["images/penguin.png", "penguin"],
    ["images/penn.jpg", "sign board"],
]
# Outputs match detect_objects' [res_json, output_image_paths] return value.
app = gr.Interface(
    detect_objects,
    inputs=[image_input, prompt_input],
    outputs=[gr.JSON(label="Output JSON"), gr.Gallery(label="Result")],
    examples=example_rows,
    title="Object Detection, Segmentation and Cropping",
    description="This app uses DINO to detect objects in an image and then uses SAM to segment and crop the objects.",
)
app.launch()