# NOTE: "Spaces: Sleeping" lines were Hugging Face Spaces page-status artifacts
# captured during extraction — they are not part of the source code.
| from typing import Tuple | |
| import gradio as gr | |
| import spaces | |
| import numpy as np | |
| import supervision as sv | |
| import torch | |
| from PIL import Image | |
| import dounseen.utils as dounseen_utils | |
| from utils.models import load_sam2_models, make_sam2_mask_generators, CHECKPOINT_NAMES, load_dounseen_model | |
| # TODO add presentation on YouTube and add link here | |
# One clickable gr.Examples row, laid out flat to match the gr.Examples inputs:
# [scene image URL, SAM2 checkpoint name, then 8 + 6 + 9 gallery image URLs for
# objects 1-3] — one URL per hidden gr.Image component declared in the UI.
EXAMPLES = [
    [
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/scene_images/example.jpg",
        "tiny",
        # obj_000001 (8 images)
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145038.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145042.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145045.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145048.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145052.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145055.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145423.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145427.jpg",
        # obj_000002 (6 images)
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000002/20240919_145450.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000002/20240919_145454.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000002/20240919_145500.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000002/20240919_145502.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000002/20240919_145506.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000002/20240919_145757.jpg",
        # obj_000003 (9 images)
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145519.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145522.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145526.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145529.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145719.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145727.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145730.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145736.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145744.jpg",
    ]
]
# Pick the compute device; fall back to CPU when CUDA is unavailable so the
# demo can at least start outside a GPU Space (original code assumed CUDA and
# crashed in get_device_properties on CPU-only hosts).
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if DEVICE.type == 'cuda':
    # Keep bfloat16 autocast active for the whole process lifetime (the usual
    # SAM2 demo pattern: enter the context once and never exit it).
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    # On Ampere (SM 8.x) and newer, allow TF32 matmuls/convolutions for speed.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

# Draws the raw SAM2 masks, coloring each detection by its index.
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

# Load every SAM2 checkpoint once at startup and build one automatic mask
# generator per checkpoint; DoUnseen provides the gallery-matching classifier.
SAM2_models = load_sam2_models(device=DEVICE)
MASK_GENERATORS = make_sam2_mask_generators(SAM2_models)
DOUNSEEN_MODEL = load_dounseen_model(device=DEVICE)
def remake_sam2_mask_generators(points_per_side):
    """Rebuild the global SAM2 mask generators with a new sampling-grid density.

    Wired to the ``point_per_side`` gr.Number input. gr.Number can deliver a
    float (no ``precision=0`` is set on the component), while SAM2's automatic
    mask generator expects an integer grid size, so coerce before rebuilding.
    """
    global MASK_GENERATORS
    MASK_GENERATORS = make_sam2_mask_generators(SAM2_models, int(points_per_side))
def setup_example_gallery(
    image, checkpoint,
    obj_1_img_1, obj_1_img_2, obj_1_img_3, obj_1_img_4, obj_1_img_5, obj_1_img_6, obj_1_img_7, obj_1_img_8,
    obj_2_img_1, obj_2_img_2, obj_2_img_3, obj_2_img_4, obj_2_img_5, obj_2_img_6,
    obj_3_img_1, obj_3_img_2, obj_3_img_3, obj_3_img_4, obj_3_img_5, obj_3_img_6, obj_3_img_7, obj_3_img_8, obj_3_img_9,
):
    """Regroup the flat example inputs into per-object gallery lists.

    gr.Examples passes every value positionally; this function repackages the
    8 + 6 + 9 individual images into three lists (one per object gallery) and
    returns a status string whose change event triggers processing.
    """
    gallery_one = [
        obj_1_img_1, obj_1_img_2, obj_1_img_3, obj_1_img_4,
        obj_1_img_5, obj_1_img_6, obj_1_img_7, obj_1_img_8,
    ]
    gallery_two = [
        obj_2_img_1, obj_2_img_2, obj_2_img_3,
        obj_2_img_4, obj_2_img_5, obj_2_img_6,
    ]
    gallery_three = [
        obj_3_img_1, obj_3_img_2, obj_3_img_3, obj_3_img_4, obj_3_img_5,
        obj_3_img_6, obj_3_img_7, obj_3_img_8, obj_3_img_9,
    ]
    return (
        image,
        checkpoint,
        gallery_one,
        gallery_two,
        gallery_three,
        "Gallery setup successfully!",
    )
def process(
    image_input,
    checkpoint_dropdown=CHECKPOINT_NAMES[0],
    obj_gallery_1=(),
    obj_gallery_2=(),
    obj_gallery_3=(),
) -> Tuple[Image.Image, Image.Image]:
    """Segment the scene with SAM2, then match the segments against the galleries.

    Args:
        image_input: PIL scene image from the UI.
        checkpoint_dropdown: SAM2 checkpoint name (a key of MASK_GENERATORS).
        obj_gallery_1/2/3: gr.Gallery values; each entry's first element is an
            image path. Empty galleries are simply skipped.
            (Defaults changed from mutable ``[]`` to ``()`` — same truthiness
            and iteration behavior, no shared-mutable-default hazard.)

    Returns:
        A pair of PIL images: (SAM2 mask visualization, DoUnseen matches).
    """
    # Build the gallery dict, skipping empty galleries so absent objects are
    # just not searched for (replaces three copy-pasted per-gallery branches).
    gallery_dict = {}
    for idx, gallery in enumerate((obj_gallery_1, obj_gallery_2, obj_gallery_3), start=1):
        if gallery:
            gallery_dict[f"obj_{idx:06d}"] = [
                Image.open(entry[0]).convert("RGB") for entry in gallery
            ]

    # Generate class-agnostic masks with the selected SAM2 checkpoint.
    model = MASK_GENERATORS[checkpoint_dropdown]
    image = np.array(image_input.convert("RGB"))
    sam2_result = model.generate(image)
    detections = sv.Detections.from_sam(sam2_result)

    # Reformat SAM2 output into the masks/bboxes layout DoUnseen expects and
    # crop out one image segment per mask.
    sam2_masks, sam2_bboxes = dounseen_utils.reformat_sam2_output(sam2_result)
    segments = dounseen_utils.get_image_segments_from_binary_masks(image, sam2_masks, sam2_bboxes)

    DOUNSEEN_MODEL.update_gallery(gallery_dict)
    # Classify every segment against all gallery objects: one instance per
    # object (multi_instance=False), dropping matches scoring below 0.6.
    class_predictions, class_scores = DOUNSEEN_MODEL.classify_all_objects(
        segments, threshold=0.6, multi_instance=False
    )
    filtered_class_predictions, filtered_class_scores, filtered_masks, filtered_bboxes = (
        dounseen_utils.remove_unmatched_query_segments(
            class_predictions, class_scores, sam2_masks, sam2_bboxes
        )
    )
    matched_query_ann_image = dounseen_utils.draw_segmented_image(
        image, filtered_masks, filtered_bboxes, filtered_class_predictions,
        classes_names=list(gallery_dict.keys()),
    )
    # Convert the annotated numpy array back to a PIL image for the UI.
    matched_query_ann_image = Image.fromarray(matched_query_ann_image)

    return MASK_ANNOTATOR.annotate(image_input, detections), matched_query_ann_image
with gr.Blocks() as demo:
    # Page header: logo, a short pitch for the DoUnseen package, GitHub badge.
    gr.Markdown("""
<div style="display: flex; flex-direction: column; align-items: center; text-align: center;">
<img src="https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/refs/heads/master/images/dounseen_logo_10.svg" alt="Dounseen Logo" style="width: 150px; margin-bottom: 10px;">
<h1>Welcome to the Dounseen Package Demo! 👋</h1>
<p><b>DoUnseen</b> is a Python package for segmenting <b>unseen objects</b>—no training or fine-tuning required!</p>
<p>🔹 Use it as a <b>standalone tool</b> to directly identify novel objects.</p>
<p>🔹 Use it <b>as an extension</b> to the <b>Segment Anything Model (SAM)</b> or any other <b>zero-shot segmentation method</b> to segment unseen objects.</p>
<p><b><span style="color: #32CD32;">👇 Click on the example below to see how to prepare your inputs!</span></b></p>
<div>
<a href="https://github.com/AnasIbrahim/image_agnostic_segmentation" target="_blank">
<img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block; margin-top: 10px;">
</a>
</div>
</div>
""")
    with gr.Row():
        # Left column: every user input.
        with gr.Column():
            # "Inputs" section banner.
            gr.HTML("""
<div style="border: 2px solid #2E7D32; padding: 10px; border-radius: 10px; background-color: #C8E6C9;">
<h3 style="text-align: center; color: #2E7D32;">Inputs</h3>
</div>
""")
            # Explanation of the main scene image.
            with gr.Row():
                gr.HTML("""
<div style="border: none; padding: 15px; border-radius: 10px; background-color: #eaf5e9;">
<h3 style="text-align: center; color: #4CAF50;">Main Scene Image</h3>
<p>Upload an image containing <b>all objects of interest</b> you want to segment.</p>
<ul>
<li><b>Flexible handling</b>: If an object is in the gallery but not in the image, it will simply not be detected.</li>
<li><b>Many objects? No problem</b>: The package works well even with multiple objects in the image.</li>
</ul>
<p><b>👉 Example:</b> A photo of a table with various items like fruits, cups, or books.</p>
</div>
""")
            # The scene image handed to process() as a PIL image.
            with gr.Row():
                image_input_component = gr.Image(type='pil', label='Scene Image')
            with gr.Row():
                # Explanation of the object galleries.
                gr.HTML("""
<div style="border: none; padding: 15px; border-radius: 10px; background-color: #eaf5e9;">
<h3 style="text-align: center; color: #4CAF50;">Object Galleries</h3>
<p><b>DoUnseen</b> can handle any number of objects, making it highly adaptable for various use cases.</p>
<ul>
<li><b>Demo limitation</b>: In this demo, you can upload up to <b>three galleries</b> of isolated images for segmentation. For fewer objects, leave the remaining galleries empty.</li>
<li><b>Object-specific</b>: Each gallery should include isolated images of a specific object, representing all its unique faces.</li>
</ul>
<p><b>👉 Example:</b> If your main scene includes a cereal box, upload isolated images of the cereal box from different angles in one gallery.</p>
</div>
""")
            # Up to three galleries, one per query object; empty ones are
            # ignored by process().
            with gr.Row():
                object_gallery_1 = gr.Gallery(label="Object 1 Images", columns=20, height="auto", preview=True)
            with gr.Row():
                object_gallery_2 = gr.Gallery(label="Object 2 Images", columns=20, height="auto", preview=True)
            with gr.Row():
                object_gallery_3 = gr.Gallery(label="Object 3 Images", columns=20, height="auto", preview=True)
            with gr.Row():
                # Explanation of the SAM2 settings.
                gr.HTML("""
<div style="border: none; padding: 15px; border-radius: 10px; background-color: #eaf5e9;">
<h3 style="text-align: center; color: #4CAF50;">SAM2 Settings</h3>
<p>The SAM2 settings control how the automatic mask generator operates.</p>
<ul>
<li><b>Checkpoint:</b> Refers to the <b>model size</b> used by SAM2.</li>
<li><b>Points per Side:</b> Defines the grid density for mask generation. Higher values increase accuracy but require more computation.</li>
</ul>
</div>
""")
            # Checkpoint selector: keys into MASK_GENERATORS inside process().
            with gr.Row():
                checkpoint_dropdown_component = gr.Dropdown(
                    choices=CHECKPOINT_NAMES,
                    value=CHECKPOINT_NAMES[0],
                    label="Checkpoint", info="Select a SAM2 checkpoint to use.",
                    interactive=True
                )
            with gr.Row():
                # NOTE(review): label reads "point_per_side" while the handler
                # parameter is points_per_side — looks like a UI-only typo;
                # left unchanged here since the label is a runtime string.
                number_input = gr.Number(label="point_per_side", value=10)
            with gr.Row():
                submit_button_component = gr.Button(value='Submit', variant='primary')
        # Right column: both output images.
        with gr.Column():
            # "Outputs" section banner.
            gr.HTML("""
<div style="border: 2px solid #555; padding: 10px; border-radius: 10px; background-color: #aaa;">
<h3 style="text-align: center; color: #333;">Outputs</h3>
</div>
""")
            # Output 1: raw SAM2 automatic-mask-generator visualization.
            with gr.Row():
                gr.HTML("""
<div style="border: none; padding: 15px; border-radius: 10px; background-color: #ddd;">
<h3 style="text-align: center; color: #555;">SAM2 Output</h3>
<p>This output visualizes the <b>masks produced by SAM2's automatic mask generator</b>.</p>
<ul>
<li><b>Adjustable Parameters:</b> If your object is not segmented by SAM2, consider increasing the <code>points_per_side</code> parameter to generate a denser grid of masks.</li>
<li><b>DoUnseen Compatibility:</b> The package can handle a large number of masks generated by SAM2, ensuring robust performance even in complex scenarios.</li>
</ul>
<div style="height: 5px;"></div> <!-- Explicit spacing -->
</div>
""")
            with gr.Row():
                image_output_sam = gr.Image(type='pil', label='SAM2 Output')
            # Output 2: DoUnseen gallery-matched segmentation.
            with gr.Row():
                gr.HTML("""
<div style="border: none; padding: 15px; border-radius: 10px; background-color: #ddd;">
<h3 style="text-align: center; color: #555;">DoUnseen Output</h3>
<p>This output visualizes the <b>results of DoUnseen</b>, showing the segmented objects identified from the gallery.</p>
<ul>
<li><b>Object handling:</b> If a gallery object is not in the scene, it will not be detected.</li>
<li><b>Multiple instances:</b> DoUnseen can detect multiple instances of each object, but in this demo, the settings are fixed to detect only one instance per gallery object.</li>
</ul>
</div>
""")
            with gr.Row():
                image_output_dounseen = gr.Image(type='pil', label='DoUnseen Output')
    with gr.Row():  # example row
        # Hidden gr.Image components (8 + 6 + 9, matching the flat EXAMPLES
        # layout) used only as gr.Examples inputs so a click can populate the
        # three galleries.
        object_images_1 = [gr.Image(type='pil', label=f"Object 1 Image {i+1}", visible=False) for i in range(8)]
        object_images_2 = [gr.Image(type='pil', label=f"Object 2 Image {i+1}", visible=False) for i in range(6)]
        object_images_3 = [gr.Image(type='pil', label=f"Object 3 Image {i+1}", visible=False) for i in range(9)]
        # Hidden text component: writing to it below chains gallery setup into
        # the process() run via its change event.
        status = gr.Text("Dummy", label="Status", visible=False)
    # Clicking the example row runs setup_example_gallery, which fills the
    # scene image, checkpoint, the three galleries, and the status text.
    gr.Examples(
        fn=setup_example_gallery,
        examples=EXAMPLES,
        inputs=[image_input_component, checkpoint_dropdown_component] + object_images_1 + object_images_2 + object_images_3,
        outputs=[image_input_component, checkpoint_dropdown_component, object_gallery_1, object_gallery_2, object_gallery_3, status],
        cache_examples=False,
        run_on_click=True
    )
    # The status update above triggers a full segmentation run automatically.
    status.change(
        fn=process,
        inputs=[image_input_component, checkpoint_dropdown_component, object_gallery_1, object_gallery_2, object_gallery_3],
        outputs=[image_output_sam, image_output_dounseen]
    )
    # Manual run on user-provided inputs.
    submit_button_component.click(
        fn=process,
        inputs=[image_input_component, checkpoint_dropdown_component, object_gallery_1, object_gallery_2, object_gallery_3],
        outputs=[image_output_sam, image_output_dounseen]
    )
    # Rebuild the SAM2 mask generators whenever the grid density changes.
    number_input.change(remake_sam2_mask_generators, inputs=[number_input])

demo.launch(debug=False, show_error=True)