from typing import Tuple
import gradio as gr
import spaces
import numpy as np
import supervision as sv
import torch
from PIL import Image
import dounseen.utils as dounseen_utils
from utils.models import load_sam2_models, make_sam2_mask_generators, CHECKPOINT_NAMES, load_dounseen_model

# TODO add presentation on YouTube and add link here

# Base URL of the GitHub repo hosting the demo assets (scene image + object galleries).
_DEMO_BASE_URL = "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo"


def _gallery_urls(obj_id, file_stems):
    """Build the raw-GitHub URLs for one object's gallery images.

    obj_id: gallery folder name, e.g. "obj_000001".
    file_stems: image file names without the ".jpg" extension.
    """
    return [f"{_DEMO_BASE_URL}/objects_gallery/{obj_id}/{stem}.jpg" for stem in file_stems]


# One example row: scene image URL, SAM2 checkpoint name, then the flattened
# gallery images for obj_000001 (8), obj_000002 (6) and obj_000003 (9) --
# matching the 8 + 6 + 9 hidden gr.Image components wired to gr.Examples below.
EXAMPLES = [
    [
        f"{_DEMO_BASE_URL}/scene_images/example.jpg",
        "tiny",
        # obj_000001
        *_gallery_urls("obj_000001", [
            "20240919_145038", "20240919_145042", "20240919_145045", "20240919_145048",
            "20240919_145052", "20240919_145055", "20240919_145423", "20240919_145427",
        ]),
        # obj_000002
        *_gallery_urls("obj_000002", [
            "20240919_145450", "20240919_145454", "20240919_145500",
            "20240919_145502", "20240919_145506", "20240919_145757",
        ]),
        # obj_000003
        *_gallery_urls("obj_000003", [
            "20240919_145519", "20240919_145522", "20240919_145526", "20240919_145529",
            "20240919_145719", "20240919_145727", "20240919_145730", "20240919_145736",
            "20240919_145744",
        ]),
    ]
]

# Fall back to CPU instead of crashing at import time when no GPU is visible.
# The original hard-coded torch.device('cuda') and queried device properties
# unconditionally, which raises on CPU-only hosts.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    # Keep the whole process under bfloat16 autocast (entering the context
    # manager globally is a deliberate app-lifetime choice from the original).
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        # Ampere (SM 8.x) and newer: allow TF32 matmuls/convolutions for speed.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

# Colors each mask by its index within the detections list.
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
SAM2_models = load_sam2_models(device=DEVICE)
# Per-checkpoint SAM2 automatic mask generators, keyed by checkpoint name.
MASK_GENERATORS = make_sam2_mask_generators(SAM2_models)
# Shared DoUnseen classifier; its gallery is replaced on every `process` call.
DOUNSEEN_MODEL = load_dounseen_model(device=DEVICE)


def remake_sam2_mask_generators(points_per_side: int) -> None:
    """Rebuild the global mask generators with a new point-grid density.

    Wired to the `point_per_side` number input's change event below.
    """
    global MASK_GENERATORS
    MASK_GENERATORS = make_sam2_mask_generators(SAM2_models, points_per_side)


def setup_example_gallery(
    image,
    checkpoint,
    obj_1_img_1, obj_1_img_2, obj_1_img_3, obj_1_img_4,
    obj_1_img_5, obj_1_img_6, obj_1_img_7, obj_1_img_8,
    obj_2_img_1, obj_2_img_2, obj_2_img_3, obj_2_img_4, obj_2_img_5, obj_2_img_6,
    obj_3_img_1, obj_3_img_2, obj_3_img_3, obj_3_img_4, obj_3_img_5,
    obj_3_img_6, obj_3_img_7, obj_3_img_8, obj_3_img_9,
):
    """Regroup the flat example images into the three gallery components.

    gr.Examples can only feed flat per-component inputs, so the example row
    arrives as 23 separate images; this repacks them into one list per object
    gallery. The trailing status string's change event triggers `process`.

    Returns: (scene image, checkpoint, gallery 1 list, gallery 2 list,
    gallery 3 list, status message).
    """
    return image, checkpoint, \
        [obj_1_img_1, obj_1_img_2, obj_1_img_3, obj_1_img_4,
         obj_1_img_5, obj_1_img_6, obj_1_img_7, obj_1_img_8], \
        [obj_2_img_1, obj_2_img_2, obj_2_img_3, obj_2_img_4, obj_2_img_5, obj_2_img_6], \
        [obj_3_img_1, obj_3_img_2, obj_3_img_3, obj_3_img_4, obj_3_img_5,
         obj_3_img_6, obj_3_img_7, obj_3_img_8, obj_3_img_9], \
        "Gallery setup successfully!"


@spaces.GPU
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process(
    image_input: Image.Image,
    checkpoint_dropdown: str = CHECKPOINT_NAMES[0],
    obj_gallery_1 = [],
    obj_gallery_2 = [],
    obj_gallery_3 = [],
) -> Tuple[Image.Image, Image.Image]:
    """Segment the scene with SAM2, then match segments against the galleries.

    image_input: PIL scene image.
    checkpoint_dropdown: key selecting which SAM2 mask generator to use.
    obj_gallery_1..3: gr.Gallery values; each entry is indexed with [0] for
        the image file path (gallery items arrive as (path, caption) pairs --
        TODO confirm against the installed Gradio version).

    Returns (SAM2 masks annotated over the scene, DoUnseen matches annotated
    over the scene), both as PIL images.
    """
    # setup gallery - handle the case if one of the galleries is empty
    gallery_dict = {}
    if obj_gallery_1:
        gallery_dict["obj_000001"] = [Image.open(object_image[0]).convert("RGB") for object_image in obj_gallery_1]
    if obj_gallery_2:
        gallery_dict["obj_000002"] = [Image.open(object_image[0]).convert("RGB") for object_image in obj_gallery_2]
    if obj_gallery_3:
        gallery_dict["obj_000003"] = [Image.open(object_image[0]).convert("RGB") for object_image in obj_gallery_3]
    model = MASK_GENERATORS[checkpoint_dropdown]
    image = np.array(image_input.convert("RGB"))
    sam2_result = model.generate(image)
    detections = sv.Detections.from_sam(sam2_result)
    # prepare sam2 output for the format expected by DoUnseen
    sam2_masks, sam2_bboxes = dounseen_utils.reformat_sam2_output(sam2_result)
    segments = dounseen_utils.get_image_segments_from_binary_masks(image, sam2_masks, sam2_bboxes)
    # Replaces any gallery from a previous request on this shared model.
    DOUNSEEN_MODEL.update_gallery(gallery_dict)
    # single object (alternative API, kept for reference)
    #matched_query, score = DOUNSEEN_MODEL.find_object(segments, obj_name="obj_000001", method="max")
    #matched_query_ann_image = dounseen_utils.draw_segmented_image(image, [sam2_masks[matched_query]], [sam2_bboxes[matched_query]], classes_predictions=[0], classes_names=["obj_000001"])
    # multiple objects: classify every segment, then drop segments that
    # matched no gallery object before drawing.
    class_predictions, class_scores = DOUNSEEN_MODEL.classify_all_objects(segments, threshold=0.6, multi_instance=False)
    filtered_class_predictions, filtered_class_scores, filtered_masks, filtered_bboxes = dounseen_utils.remove_unmatched_query_segments(class_predictions, class_scores, sam2_masks, sam2_bboxes)
    matched_query_ann_image = dounseen_utils.draw_segmented_image(image, filtered_masks, filtered_bboxes, filtered_class_predictions, classes_names=list(gallery_dict.keys()))
    # convert to PIL image
    matched_query_ann_image = Image.fromarray(matched_query_ann_image)
    return MASK_ANNOTATOR.annotate(image_input, detections), matched_query_ann_image


with gr.Blocks() as demo:
    # Intro banner shown above the two-column layout.
    gr.Markdown("""
Dounseen Logo

Welcome to the Dounseen Package Demo! 👋

DoUnseen is a Python package for segmenting unseen objects—no training or fine-tuning required!

🔹 Use it as a standalone tool to directly identify novel objects.

🔹 Use it as an extension to the Segment Anything Model (SAM) or any other zero-shot segmentation method to segment unseen objects.

👇 Click on the example below to see how to prepare your inputs!

GitHub
""")
    with gr.Row():
        # Left column: all inputs (scene image, galleries, SAM2 settings).
        with gr.Column():
            # Frame for the "Inputs" tag
            gr.HTML("""

Inputs

""")
            # Frame for the "Main Scene Image" description
            with gr.Row():
                gr.HTML("""

Main Scene Image

Upload an image containing all objects of interest you want to segment.

👉 Example: A photo of a table with various items like fruits, cups, or books.

""")
            # Frame for the image input
            with gr.Row():
                image_input_component = gr.Image(type='pil', label='Scene Image')
            with gr.Row():
                # Frame for the "Gallery" section
                gr.HTML("""

Object Galleries

DoUnseen can handle any number of objects, making it highly adaptable for various use cases.

👉 Example: If your main scene includes a cereal box, upload isolated images of the cereal box from different angles in one gallery.

""")
            with gr.Row():
                object_gallery_1 = gr.Gallery(label="Object 1 Images", columns=20, height="auto", preview=True)
            with gr.Row():
                object_gallery_2 = gr.Gallery(label="Object 2 Images", columns=20, height="auto", preview=True)
            with gr.Row():
                object_gallery_3 = gr.Gallery(label="Object 3 Images", columns=20, height="auto", preview=True)
            with gr.Row():
                # Frame for the "Checkpoint" section
                gr.HTML("""

SAM2 Settings

The SAM2 settings control how the automatic mask generator operates.

""")
            with gr.Row():
                checkpoint_dropdown_component = gr.Dropdown(
                    choices=CHECKPOINT_NAMES,
                    value=CHECKPOINT_NAMES[0],
                    label="Checkpoint",
                    info="Select a SAM2 checkpoint to use.",
                    interactive=True
                )
            with gr.Row():
                number_input = gr.Number(label="point_per_side", value=10)
            with gr.Row():
                submit_button_component = gr.Button(value='Submit', variant='primary')
        # Right column: the two annotated result images.
        with gr.Column():
            # Outputs tag with darker color and dark gray background
            gr.HTML("""

Outputs

""")
            # Output 1: SAM2 Output with lighter gray fill
            with gr.Row():
                gr.HTML("""

SAM2 Output

This output visualizes the masks produced by SAM2's automatic mask generator.

""")
            with gr.Row():
                image_output_sam = gr.Image(type='pil', label='SAM2 Output')
            # Output 2: DoUnseen Output with lighter gray fill
            with gr.Row():
                gr.HTML("""

DoUnseen Output

This output visualizes the results of DoUnseen, showing the segmented objects identified from the gallery.

""")
            with gr.Row():
                image_output_dounseen = gr.Image(type='pil', label='DoUnseen Output')
    with gr.Row():  # example Row
        # make a list of hidden gr.Image components to store the gallery images
        # (8 + 6 + 9, matching the flat layout of each EXAMPLES row)
        object_images_1 = [gr.Image(type='pil', label=f"Object 1 Image {i+1}", visible=False) for i in range(8)]
        object_images_2 = [gr.Image(type='pil', label=f"Object 2 Image {i+1}", visible=False) for i in range(6)]
        object_images_3 = [gr.Image(type='pil', label=f"Object 3 Image {i+1}", visible=False) for i in range(9)]
        # use a dummy text component to store the status of the gallery setup
        status = gr.Text("Dummy", label="Status", visible=False)
        # setup example gallery
        # clicking on the example will load the example images into the gallery in the expected format
        gr.Examples(
            fn=setup_example_gallery,
            examples=EXAMPLES,
            inputs=[image_input_component, checkpoint_dropdown_component] + object_images_1 + object_images_2 + object_images_3,
            outputs=[image_input_component, checkpoint_dropdown_component, object_gallery_1, object_gallery_2, object_gallery_3, status],
            cache_examples=False,
            run_on_click=True
        )
    # This will trigger the process function after gallery is set up in the expected format
    status.change(
        fn=process,
        inputs=[image_input_component, checkpoint_dropdown_component, object_gallery_1, object_gallery_2, object_gallery_3],
        outputs=[image_output_sam, image_output_dounseen]
    )
    submit_button_component.click(
        fn=process,
        inputs=[image_input_component, checkpoint_dropdown_component, object_gallery_1, object_gallery_2, object_gallery_3],
        outputs=[image_output_sam, image_output_dounseen]
    )
    # Changing the point density rebuilds the mask generators globally.
    number_input.change(remake_sam2_mask_generators, inputs=[number_input])

demo.launch(debug=False, show_error=True)