| | import os |
| | import cv2 |
| | import time |
| | import torch |
| | import spaces |
| | import subprocess |
| | import numpy as np |
| | import gradio as gr |
| | import urllib.request |
| | from PIL import Image, ImageDraw |
| | import matplotlib.pyplot as plt |
| |
|
| | from Garage.models.GroundedSegmentAnything.segment_anything.segment_anything import SamPredictor, build_sam, sam_model_registry |
| | from Garage.models.GroundedSegmentAnything.GroundingDINO.groundingdino.util.inference import Model |
| | from Garage import Augmenter |
| |
|
| |
|
# SAM checkpoint download URLs, keyed by encoder variant.
MODEL_DICT = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
}

# GroundingDINO config/weights plus the SAM checkpoint and encoder used by the app.
GROUNDING_DINO_CONFIG_PATH = "Garage/models/GroundedSegmentAnything/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
SAM_CHECKPOINT_PATH = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
SAM_ENCODER_VERSION = "vit_h"
| |
|
class GradioWindow():
    """Gradio front-end for text-grounded segmentation and image augmentation.

    Wires together GroundingDINO (text prompt -> boxes), SAM (boxes -> masks)
    and the Garage ``Augmenter`` (image + mask + prompt -> edited image) into a
    single ``gr.Blocks`` demo.  Per-session mask state lives on the instance:
    ``self.mask`` (candidate masks), ``self.selected_mask`` (user's pick),
    ``self.segmentation_mask`` (raw SAM masks).
    """

    def __init__(self) -> None:
        self.points = []                # last click position as [x, y]
        self.mask = []                  # candidate masks (PIL images)
        self.selected_mask = None       # mask chosen by the user, or None
        self.segmentation_mask = []     # raw segmentation masks (PIL images)
        self.concatenated_masks = None  # last rendered mask overlay
        # index -> [object name, path to its example mask file]
        self.examples_masks = {
            0: ["dog", "examples/dog_mask.jpg"],
            1: ["bread", "examples/bread_mask.jpg"],
            2: ["room", "examples/room_mask.jpg"],
            3: ["spoon", "examples/spoon_mask.jpg"],
            4: ["cat", "examples/image_mask.jpg"],
        }

        self.GROUNDING_DINO_CONFIG_PATH = GROUNDING_DINO_CONFIG_PATH
        self.GROUNDING_DINO_CHECKPOINT_PATH = GROUNDING_DINO_CHECKPOINT_PATH
        self.model_type = SAM_ENCODER_VERSION
        self.SAM_CHECKPOINT_PATH = SAM_CHECKPOINT_PATH

        # NOTE(review): CPU-only by default — change manually for GPU hosts.
        self.device = "cpu"

        self.augmenter = Augmenter(device=self.device)
        self.setup_model()
        self.main()

    def main(self):
        """Build the Blocks layout and register all event handlers on ``self.demo``."""
        with gr.Blocks() as self.demo:
            with gr.Row():
                input_img = gr.Image(type="pil", label="Input image", interactive=True)
                selected_mask = gr.Image(type="pil", label="Selected Mask", interactive=True)
                segmented_img = gr.Image(type="pil", label="Selected Segment")

            with gr.Row():
                with gr.Group():
                    gr.Markdown(
                        "## Grounded Segmentation\n"
                        "#### This tool segments the object in the image based on the text prompt via GroundedSAM model. "
                        "You can also load the mask of the object to segment or choose one of the examples below.\n"
                    )
                    self.current_object = gr.Textbox(label="Current object")
                    with gr.Accordion("Advanced options", open=False):
                        self.use_mask = gr.Checkbox(label="Use segmentation mask", value=False)
                        box_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Box threshold")
                        text_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Text threshold")

                    segment_object = gr.Button("Segment object")

                with gr.Column():
                    gr.Examples(
                        label="Images Examples",
                        examples=[
                            ["examples/dog.jpg"],
                            ["examples/bread.png"],
                            ["examples/room.jpg"],
                            ["examples/spoon.png"],
                            ["examples/image.jpg"],
                        ],
                        inputs=[input_img],
                        examples_per_page=5
                    )
                    # Clicking a mask example routes through set_mask so the
                    # matching object name and mask state are filled in too.
                    gr.Examples(
                        label="Mask Examples",
                        examples=[
                            [self.examples_masks[0][1]],
                            [self.examples_masks[1][1]],
                            [self.examples_masks[2][1]],
                            [self.examples_masks[3][1]],
                            [self.examples_masks[4][1]],
                        ],
                        inputs=[selected_mask, input_img],
                        outputs=[segmented_img, self.current_object, self.use_mask],
                        fn=self.set_mask,
                        run_on_click=True
                    )

            with gr.Row():
                with gr.Column():
                    with gr.Group():
                        gr.Markdown(
                            "## Augmentation\n"
                            "#### This tool generates an augmented image based on the input image, the object to augment, and the target object. "
                            "If you don't specify the target object, the model will generate a random object. "
                            "You can also specify the number of steps, guidance scale, and seed for the generation process.\n"
                        )
                        self.target_object = gr.Textbox(label="Target object")

                        with gr.Accordion("Generation options", open=False):
                            self.iter_number = gr.Number(value=50, label="Steps")
                            self.guidance_scale = gr.Number(value=5, label="Guidance Scale")
                            self.seed = gr.Number(value=1, label="Seed")
                            self.return_prompt = gr.Checkbox(value=True, label="Show generated prompt")

                        enter_prompt = gr.Button("Augment Image")

                with gr.Column():
                    augmented_img = gr.Image(type="pil", label="Augmented Image")
                    generated_prompt = gr.Markdown(
                        '<div class="message" style="text-align: center; font-size: 18px;"></div>',
                        visible=True)

            # Uploading a custom mask behaves like picking a mask example.
            selected_mask.upload(
                self.set_mask,
                inputs=[selected_mask, input_img],
                outputs=[segmented_img, self.current_object, self.use_mask],
            )

            segment_object.click(
                self.detect,
                inputs=[input_img, self.current_object,
                        self.use_mask, box_threshold,
                        text_threshold],
                outputs=[segmented_img, selected_mask]
            )

            # Toggling the checkbox re-renders masks as segments or bboxes.
            self.use_mask.change(
                fn=self.change_mask_type,
                inputs=[input_img, self.use_mask],
                outputs=[selected_mask, segmented_img],
            )

            # Clicking on the overlay picks the mask under the cursor.
            segmented_img.select(
                self.select_mask,
                inputs=[input_img],
                outputs=[selected_mask, segmented_img],
            )

            enter_prompt.click(
                self.augment_image,
                inputs=[input_img, self.current_object, self.target_object,
                        self.iter_number, self.guidance_scale, self.seed, self.return_prompt],
                outputs=[augmented_img, generated_prompt],
            )

    def setup_model(self) -> None:
        """Load SAM and GroundingDINO onto ``self.device``.

        Fix: use the configured ``self.model_type`` / instance checkpoint paths
        instead of hard-coding "vit_h"; the return annotation was also wrong
        (this method returns nothing, it populates instance attributes).
        """
        self.sam = sam_model_registry[self.model_type]()
        self.sam.load_state_dict(torch.utils.model_zoo.load_url(MODEL_DICT[self.model_type]))
        self.sam.to(device=self.device)
        self.sam_predictor = SamPredictor(self.sam)

        self.grounding_dino_model = Model(
            model_config_path=self.GROUNDING_DINO_CONFIG_PATH,
            model_checkpoint_path=self.GROUNDING_DINO_CHECKPOINT_PATH,
            device=self.device
        )

        print("MODELS LOADED! Device:", self.device)

    def change_mask_type(self, image, is_segmmask):
        """Re-render stored masks either as segmentation masks or as bbox masks.

        Args:
            image: the current input image (PIL).
            is_segmmask: True -> pixel-accurate masks, False -> bounding-box masks.
        Returns:
            (combined binary mask, overlay image) for the two mask widgets.
        """
        self.selected_mask = None
        masks = []
        if is_segmmask:
            for segm_mask in self.segmentation_mask:
                gray_mask = np.array(segm_mask)
                if gray_mask.ndim == 3:
                    gray_mask = gray_mask[:, :, 0]
                # Binarize: example masks are JPEGs, so threshold compression noise.
                masks.append(gray_mask > 200)
        else:
            for segm_mask in self.segmentation_mask:
                masks.append(np.array(self.get_bbox_mask(segm_mask)))
        # concatenate_masks rebuilds self.mask from `masks` itself.
        res, common_mask = self.concatenate_masks(masks, image)
        return common_mask, res

    def get_bbox_mask(self, mask):
        """Return a filled-rectangle ("L" mode) mask covering ``mask``'s bbox.

        Empty masks (no bbox) yield an all-zero mask of the same size.
        """
        bbox = mask.getbbox()
        new_mask = Image.new("L", mask.size, 0)
        draw = ImageDraw.Draw(new_mask)
        if bbox:
            draw.rectangle(bbox, fill=255)
        return new_mask

    def select_mask(self, image: Image, evt: gr.SelectData):
        """Pick the first stored mask containing the clicked pixel.

        Fix: removed blocking ``plt.imshow``/``plt.show`` debug calls — they
        open a GUI window inside a server callback and stall the request.
        """
        self.points = [evt.index[0], evt.index[1]]
        selected_mask = np.zeros_like(image)
        self.selected_mask = None
        for mask in self.mask:
            mask = np.array(mask)
            # evt.index is (x, y); numpy indexing is [row, col] = [y, x].
            if mask[self.points[1]][self.points[0]]:
                self.selected_mask = Image.fromarray(mask)
                color = np.array([30 / 255, 144 / 255, 255 / 255])
                selected_mask[mask > 0] = color.reshape(1, 1, -1) * 255
                selected_mask = Image.fromarray(selected_mask, mode="RGB")
                break

        res = self.show_mask(selected_mask, image)
        self.concatenated_masks = res
        return self.selected_mask, res

    def set_mask(self, mask: Image, image: Image):
        """Install an uploaded/example mask and look up its known object name.

        Fixes: ``Image.open`` now runs in a context manager (file-handle leak),
        and single-channel masks no longer crash channel selection.
        Returns (overlay image, object name or None, True) — the final True
        switches the UI to segmentation-mask mode.
        """
        self.selected_mask = mask
        self.segmentation_mask = [mask]
        current_object = None

        # Recognize the example masks by exact pixel equality.
        for key, value in self.examples_masks.items():
            with Image.open(value[1]) as m:
                if np.array_equal(np.array(m), np.array(mask)):
                    current_object = value[0]
                    break

        gray_mask = np.array(mask)
        if gray_mask.ndim == 3:
            gray_mask = gray_mask[:, :, 0]
        bin_mask = gray_mask > 200

        _, common_mask = self.concatenate_masks([bin_mask], image)
        self.mask = [Image.fromarray(bin_mask)]
        res = self.show_mask(common_mask, image)
        self.concatenated_masks = res
        return res, current_object, True

    def detect(self, image: Image, prompt: str, is_segmmask: bool,
               box_threshold: float, text_threshold: float):
        """Detect ``prompt`` with GroundingDINO, segment each box with SAM.

        Returns (overlay image, combined binary mask); both are zero images
        when nothing is detected.
        """
        # NOTE(review): `image` is PIL (RGB), so BGR2RGB actually produces a
        # BGR array here — kept as-is since both models receive it consistently;
        # confirm against Model.predict_with_classes' expected channel order.
        detections = self.grounding_dino_model.predict_with_classes(
            image=cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB),
            classes=[prompt],
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )

        detections.mask = self.segment(
            sam_predictor=self.sam_predictor,
            image=cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB),
            xyxy=detections.xyxy
        )

        if len(detections.mask) == 0:
            return np.array(image), Image.fromarray(np.zeros_like(np.array(image)))

        self.segmentation_mask = [Image.fromarray(m) for m in detections.mask]

        if is_segmmask:
            image, common_mask = self.concatenate_masks(detections.mask, image)
        else:
            masks = [np.array(self.get_bbox_mask(Image.fromarray(m)))
                     for m in detections.mask]
            image, common_mask = self.concatenate_masks(masks, image)

        return image, common_mask

    def concatenate_masks(self, masks: np.ndarray, image: Image) -> tuple:
        """Overlay all masks on ``image`` and merge them into one binary mask.

        Side effect: rebuilds ``self.mask`` from ``masks``.  The first mask is
        drawn in a fixed blue, subsequent ones in random colors.
        Returns (overlay PIL image, combined 0/255 uint8 PIL mask).
        (Fix: return annotation said ``np.ndarray`` but a tuple is returned.)
        """
        self.mask = []
        random_color = False
        common_mask = np.zeros_like(image)
        for i, mask in enumerate(masks):
            if random_color:
                color = np.concatenate([np.random.random(3)], axis=0)
            else:
                color = np.array([30 / 255, 144 / 255, 255 / 255])

            self.mask.append(Image.fromarray(mask))
            common_mask[mask > 0] = color.reshape(1, 1, -1) * 255
            random_color = True

        common_mask = Image.fromarray(common_mask, mode="RGB")
        image = self.show_mask(common_mask, image, random_color)

        common_mask = np.where(np.array(common_mask) != 0, 255, 0).astype(np.uint8)
        return Image.fromarray(image), Image.fromarray(common_mask)

    def show_mask(self, mask: Image, image: Image,
                  random_color: bool = False) -> np.ndarray:
        """Visualize a mask on top of an image.

        Args:
            mask (Image): mask image of shape (H, W, 3); resized to the image.
            image (Image): image of shape (H, W, 3).
            random_color (bool): unused here, kept for caller compatibility.
        Returns:
            np.ndarray: (H, W, 3) array blending 70% image with 30% mask.
        """
        mask, image = np.array(mask), np.array(image)
        target_size = (image.shape[1], image.shape[0])
        # Nearest-neighbor keeps mask edges hard instead of interpolating.
        mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
        image = cv2.addWeighted(image, 0.7, mask, 0.3, 0)
        return image

    def segment(self, sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
        """Run SAM on each xyxy box and keep the highest-scoring mask per box."""
        sam_predictor.set_image(image)
        result_masks = []
        for box in xyxy:
            masks, scores, logits = sam_predictor.predict(
                box=box,
                multimask_output=True
            )
            result_masks.append(masks[np.argmax(scores)])
        return np.array(result_masks)

    def augment_image(self, image: Image,
                      current_object: str, new_objects_list: str,
                      ddim_steps: int, guidance_scale: int, seed: int, return_prompt: bool) -> tuple:
        """Run the Garage augmenter on the selected (or a random stored) mask.

        Args:
            image: input image (PIL).
            current_object: name of the object being replaced.
            new_objects_list: comma-separated target objects ("a, b, c").
            ddim_steps / guidance_scale / seed: generation parameters.
            return_prompt: whether to display the generated prompt.
            (Fix: ``return_prompt`` is a bool from a Checkbox, not str.)
        Returns:
            (augmented image, HTML snippet with the generated prompt).
        """
        if self.selected_mask is not None:
            mask = self.selected_mask
        else:
            # No explicit selection: fall back to a random candidate mask.
            mask = self.mask[np.random.choice(len(self.mask))]

        new_objects_list = new_objects_list.split(", ")

        result, (prompt, _) = self.augmenter(
            image=image,
            mask=mask,
            current_object=current_object,
            new_objects_list=new_objects_list,
            ddim_steps=ddim_steps,
            guidance_scale=guidance_scale,
            seed=seed,
            return_prompt=return_prompt
        )

        if not return_prompt:
            prompt = ""

        prompt_message = (
            f'<div class="message" style="text-align: center; '
            f'font-size: 18px;">Generated prompt: {prompt}</div>'
        )
        return result, prompt_message
| | |
if __name__ == "__main__":
    # Construct the app (loads models and builds the UI), serve it locally,
    # then shut the demo down once launch() returns.
    app = GradioWindow()
    demo = app.demo
    demo.launch(share=False)
    demo.close()