# dounseen / app.py — HuggingFace Space entry point for the DoUnseen demo.
# (Scraped page header removed: author "anas-gouda", commit 0a1524a
# "latest dounseen changes".)
from typing import Tuple
import gradio as gr
import spaces
import numpy as np
import supervision as sv
import torch
from PIL import Image
import dounseen.utils as dounseen_utils
from utils.models import load_sam2_models, make_sam2_mask_generators, CHECKPOINT_NAMES, load_dounseen_model
# TODO add presentation on YouTube and add link here
# One flat example row for gr.Examples: scene image URL, SAM2 checkpoint name,
# then 23 object-gallery image URLs (8 for obj_000001, 6 for obj_000002,
# 9 for obj_000003). setup_example_gallery() regroups this flat list into the
# three per-object gallery lists — the counts here must match the hidden
# gr.Image component lists created in the UI below.
EXAMPLES = [
    [
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/scene_images/example.jpg",
        "tiny",
        # obj_000001
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145038.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145042.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145045.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145048.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145052.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145055.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145423.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000001/20240919_145427.jpg",
        # obj_000002
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000002/20240919_145450.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000002/20240919_145454.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000002/20240919_145500.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000002/20240919_145502.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000002/20240919_145506.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000002/20240919_145757.jpg",
        # obj_000003
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145519.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145522.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145526.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145529.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145719.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145727.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145730.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145736.jpg",
        "https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/master/demo/objects_gallery/obj_000003/20240919_145744.jpg",
    ]
]
# ---------------------------------------------------------------------------
# Global runtime setup: models are loaded once at import time so every
# Gradio request reuses them.
# ---------------------------------------------------------------------------
# NOTE(review): device is hard-coded to CUDA — this crashes on CPU-only
# hosts; presumably the Space always provides a GPU (see @spaces.GPU below).
DEVICE = torch.device('cuda')
# Keep a process-wide bfloat16 autocast context open (entered manually and
# never exited — the pattern used in the official SAM2 example notebooks).
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # Ampere (compute capability >= 8.0) and newer: allow TF32 matmul/conv
    # for faster inference at slightly reduced precision.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# Annotator that colors each mask by its index within the detection list.
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
# Load all SAM2 checkpoints once; MASK_GENERATORS is rebuilt whenever the
# user changes points_per_side (see remake_sam2_mask_generators).
SAM2_models = load_sam2_models(device=DEVICE)
MASK_GENERATORS = make_sam2_mask_generators(SAM2_models)
DOUNSEEN_MODEL = load_dounseen_model(device=DEVICE)
def remake_sam2_mask_generators(points_per_side):
    """Rebuild the SAM2 automatic mask generators with a new grid density.

    Replaces the module-level MASK_GENERATORS mapping so the next process()
    call picks up the new ``points_per_side`` setting. Triggered by the
    ``number_input.change`` event in the UI.
    """
    global MASK_GENERATORS
    MASK_GENERATORS = make_sam2_mask_generators(SAM2_models, points_per_side)
def setup_example_gallery(
    image, checkpoint,
    obj_1_img_1, obj_1_img_2, obj_1_img_3, obj_1_img_4, obj_1_img_5, obj_1_img_6, obj_1_img_7, obj_1_img_8,
    obj_2_img_1, obj_2_img_2, obj_2_img_3, obj_2_img_4, obj_2_img_5, obj_2_img_6,
    obj_3_img_1, obj_3_img_2, obj_3_img_3, obj_3_img_4, obj_3_img_5, obj_3_img_6, obj_3_img_7, obj_3_img_8, obj_3_img_9,
):
    """Regroup the flat example values into the three per-object galleries.

    Clicking the example feeds 25 flat values (scene image, checkpoint name
    and 23 object images). Gradio galleries expect lists, so this repacks the
    images into one list per object and returns a status string that the
    hidden ``status`` textbox uses to trigger processing.
    """
    gallery_1 = [
        obj_1_img_1, obj_1_img_2, obj_1_img_3, obj_1_img_4,
        obj_1_img_5, obj_1_img_6, obj_1_img_7, obj_1_img_8,
    ]
    gallery_2 = [
        obj_2_img_1, obj_2_img_2, obj_2_img_3,
        obj_2_img_4, obj_2_img_5, obj_2_img_6,
    ]
    gallery_3 = [
        obj_3_img_1, obj_3_img_2, obj_3_img_3, obj_3_img_4, obj_3_img_5,
        obj_3_img_6, obj_3_img_7, obj_3_img_8, obj_3_img_9,
    ]
    return (
        image,
        checkpoint,
        gallery_1,
        gallery_2,
        gallery_3,
        "Gallery setup successfully!",
    )
@spaces.GPU
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process(
    image_input,
    checkpoint_dropdown=CHECKPOINT_NAMES[0],
    obj_gallery_1=None,  # was a mutable default ([]); None is equivalent since only truthiness is tested
    obj_gallery_2=None,
    obj_gallery_3=None,
) -> Tuple[Image.Image, Image.Image]:
    """Segment the scene with SAM2 and match segments against the galleries.

    Parameters:
        image_input: PIL scene image uploaded by the user.
        checkpoint_dropdown: name of the SAM2 checkpoint / mask generator.
        obj_gallery_1/2/3: Gradio gallery values; each item's index 0 is a
            file path. Empty/None galleries are skipped.

    Returns:
        A pair of PIL images: the raw SAM2 mask overlay and the DoUnseen
        classification overlay.
    """
    # Build the gallery dict, skipping galleries the user left empty.
    gallery_dict = {}
    named_galleries = (
        ("obj_000001", obj_gallery_1),
        ("obj_000002", obj_gallery_2),
        ("obj_000003", obj_gallery_3),
    )
    for obj_name, gallery in named_galleries:
        if gallery:
            gallery_dict[obj_name] = [
                Image.open(item[0]).convert("RGB") for item in gallery
            ]
    model = MASK_GENERATORS[checkpoint_dropdown]
    image = np.array(image_input.convert("RGB"))
    sam2_result = model.generate(image)
    detections = sv.Detections.from_sam(sam2_result)
    # Reformat SAM2 output into the masks/bboxes layout DoUnseen expects.
    sam2_masks, sam2_bboxes = dounseen_utils.reformat_sam2_output(sam2_result)
    segments = dounseen_utils.get_image_segments_from_binary_masks(
        image, sam2_masks, sam2_bboxes
    )
    DOUNSEEN_MODEL.update_gallery(gallery_dict)
    # Classify every SAM2 segment against the gallery. The demo is fixed to
    # one instance per gallery object (multi_instance=False).
    class_predictions, class_scores = DOUNSEEN_MODEL.classify_all_objects(
        segments, threshold=0.6, multi_instance=False
    )
    # Drop segments that did not match any gallery object.
    (
        filtered_class_predictions,
        filtered_class_scores,
        filtered_masks,
        filtered_bboxes,
    ) = dounseen_utils.remove_unmatched_query_segments(
        class_predictions, class_scores, sam2_masks, sam2_bboxes
    )
    matched_query_ann_image = dounseen_utils.draw_segmented_image(
        image,
        filtered_masks,
        filtered_bboxes,
        filtered_class_predictions,
        classes_names=list(gallery_dict.keys()),
    )
    # Convert the annotated numpy array back to a PIL image for Gradio.
    matched_query_ann_image = Image.fromarray(matched_query_ann_image)
    return MASK_ANNOTATOR.annotate(image_input, detections), matched_query_ann_image
# ---------------------------------------------------------------------------
# Gradio UI: two columns (inputs left, outputs right) plus a hidden row of
# components that backs the clickable example.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    # Header: logo, intro text and GitHub badge.
    gr.Markdown("""
    <div style="display: flex; flex-direction: column; align-items: center; text-align: center;">
        <img src="https://raw.githubusercontent.com/AnasIbrahim/image_agnostic_segmentation/refs/heads/master/images/dounseen_logo_10.svg" alt="Dounseen Logo" style="width: 150px; margin-bottom: 10px;">
        <h1>Welcome to the Dounseen Package Demo! 👋</h1>
        <p><b>DoUnseen</b> is a Python package for segmenting <b>unseen objects</b>—no training or fine-tuning required!</p>
        <p>🔹 Use it as a <b>standalone tool</b> to directly identify novel objects.</p>
        <p>🔹 Use it <b>as an extension</b> to the <b>Segment Anything Model (SAM)</b> or any other <b>zero-shot segmentation method</b> to segment unseen objects.</p>
        <p><b><span style="color: #32CD32;">👇 Click on the example below to see how to prepare your inputs!</span></b></p>
        <div>
            <a href="https://github.com/AnasIbrahim/image_agnostic_segmentation" target="_blank">
                <img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block; margin-top: 10px;">
            </a>
        </div>
    </div>
    """)
    with gr.Row():
        with gr.Column():
            # Frame for the "Inputs" tag
            gr.HTML("""
            <div style="border: 2px solid #2E7D32; padding: 10px; border-radius: 10px; background-color: #C8E6C9;">
                <h3 style="text-align: center; color: #2E7D32;">Inputs</h3>
            </div>
            """)
            # Frame for the "Main Scene Image" description
            with gr.Row():
                gr.HTML("""
                <div style="border: none; padding: 15px; border-radius: 10px; background-color: #eaf5e9;">
                    <h3 style="text-align: center; color: #4CAF50;">Main Scene Image</h3>
                    <p>Upload an image containing <b>all objects of interest</b> you want to segment.</p>
                    <ul>
                        <li><b>Flexible handling</b>: If an object is in the gallery but not in the image, it will simply not be detected.</li>
                        <li><b>Many objects? No problem</b>: The package works well even with multiple objects in the image.</li>
                    </ul>
                    <p><b>👉 Example:</b> A photo of a table with various items like fruits, cups, or books.</p>
                </div>
                """)
            # Frame for the image input
            with gr.Row():
                image_input_component = gr.Image(type='pil', label='Scene Image')
            with gr.Row():
                # Frame for the "Gallery" section
                gr.HTML("""
                <div style="border: none; padding: 15px; border-radius: 10px; background-color: #eaf5e9;">
                    <h3 style="text-align: center; color: #4CAF50;">Object Galleries</h3>
                    <p><b>DoUnseen</b> can handle any number of objects, making it highly adaptable for various use cases.</p>
                    <ul>
                        <li><b>Demo limitation</b>: In this demo, you can upload up to <b>three galleries</b> of isolated images for segmentation. For fewer objects, leave the remaining galleries empty.</li>
                        <li><b>Object-specific</b>: Each gallery should include isolated images of a specific object, representing all its unique faces.</li>
                    </ul>
                    <p><b>👉 Example:</b> If your main scene includes a cereal box, upload isolated images of the cereal box from different angles in one gallery.</p>
                </div>
                """)
            # Up to three galleries of isolated object views; process() maps
            # them to obj_000001..obj_000003 in order.
            with gr.Row():
                object_gallery_1 = gr.Gallery(label="Object 1 Images", columns=20, height="auto", preview=True)
            with gr.Row():
                object_gallery_2 = gr.Gallery(label="Object 2 Images", columns=20, height="auto", preview=True)
            with gr.Row():
                object_gallery_3 = gr.Gallery(label="Object 3 Images", columns=20, height="auto", preview=True)
            with gr.Row():
                # Frame for the "Checkpoint" section
                gr.HTML("""
                <div style="border: none; padding: 15px; border-radius: 10px; background-color: #eaf5e9;">
                    <h3 style="text-align: center; color: #4CAF50;">SAM2 Settings</h3>
                    <p>The SAM2 settings control how the automatic mask generator operates.</p>
                    <ul>
                        <li><b>Checkpoint:</b> Refers to the <b>model size</b> used by SAM2.</li>
                        <li><b>Points per Side:</b> Defines the grid density for mask generation. Higher values increase accuracy but require more computation.</li>
                    </ul>
                </div>
                """)
            with gr.Row():
                checkpoint_dropdown_component = gr.Dropdown(
                    choices=CHECKPOINT_NAMES,
                    value=CHECKPOINT_NAMES[0],
                    label="Checkpoint", info="Select a SAM2 checkpoint to use.",
                    interactive=True
                )
            with gr.Row():
                number_input = gr.Number(label="point_per_side", value=10)
            with gr.Row():
                submit_button_component = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            # Outputs tag with darker color and dark gray background
            gr.HTML("""
            <div style="border: 2px solid #555; padding: 10px; border-radius: 10px; background-color: #aaa;">
                <h3 style="text-align: center; color: #333;">Outputs</h3>
            </div>
            """)
            # Output 1: SAM2 Output with lighter gray fill
            with gr.Row():
                gr.HTML("""
                <div style="border: none; padding: 15px; border-radius: 10px; background-color: #ddd;">
                    <h3 style="text-align: center; color: #555;">SAM2 Output</h3>
                    <p>This output visualizes the <b>masks produced by SAM2's automatic mask generator</b>.</p>
                    <ul>
                        <li><b>Adjustable Parameters:</b> If your object is not segmented by SAM2, consider increasing the <code>points_per_side</code> parameter to generate a denser grid of masks.</li>
                        <li><b>DoUnseen Compatibility:</b> The package can handle a large number of masks generated by SAM2, ensuring robust performance even in complex scenarios.</li>
                    </ul>
                    <div style="height: 5px;"></div> <!-- Explicit spacing -->
                </div>
                """)
            with gr.Row():
                image_output_sam = gr.Image(type='pil', label='SAM2 Output')
            # Output 2: DoUnseen Output with lighter gray fill
            with gr.Row():
                gr.HTML("""
                <div style="border: none; padding: 15px; border-radius: 10px; background-color: #ddd;">
                    <h3 style="text-align: center; color: #555;">DoUnseen Output</h3>
                    <p>This output visualizes the <b>results of DoUnseen</b>, showing the segmented objects identified from the gallery.</p>
                    <ul>
                        <li><b>Object handling:</b> If a gallery object is not in the scene, it will not be detected.</li>
                        <li><b>Multiple instances:</b> DoUnseen can detect multiple instances of each object, but in this demo, the settings are fixed to detect only one instance per gallery object.</li>
                    </ul>
                </div>
                """)
            with gr.Row():
                image_output_dounseen = gr.Image(type='pil', label='DoUnseen Output')
    with gr.Row():  # example Row
        # make a list of hidden 6 gr.Image components to store the gallery images
        # (8 + 6 + 9 hidden images; counts must match the EXAMPLES row layout)
        object_images_1 = [gr.Image(type='pil', label=f"Object 1 Image {i+1}", visible=False) for i in range(8)]
        object_images_2 = [gr.Image(type='pil', label=f"Object 2 Image {i+1}", visible=False) for i in range(6)]
        object_images_3 = [gr.Image(type='pil', label=f"Object 3 Image {i+1}", visible=False) for i in range(9)]
        # use a dummy text component to store the status of the gallery setup
        status = gr.Text("Dummy", label="Status", visible=False)
        # setup example gallery
        # clicking on the example will load the example images into the gallery in the expected format
        gr.Examples(
            fn=setup_example_gallery,
            examples=EXAMPLES,
            inputs=[image_input_component, checkpoint_dropdown_component] + object_images_1 + object_images_2 + object_images_3,
            outputs=[image_input_component, checkpoint_dropdown_component, object_gallery_1, object_gallery_2, object_gallery_3, status],
            cache_examples=False,
            run_on_click=True
        )
    # This will trigger the process function after gallery is set up in the expected format
    # NOTE(review): .change fires only when the value differs from the previous
    # one; setup_example_gallery always returns the same status string, so a
    # second example click may not re-trigger process() — confirm intended.
    status.change(
        fn=process,
        inputs=[image_input_component, checkpoint_dropdown_component, object_gallery_1, object_gallery_2, object_gallery_3],
        outputs=[image_output_sam, image_output_dounseen]
    )
    submit_button_component.click(
        fn=process,
        inputs=[image_input_component, checkpoint_dropdown_component, object_gallery_1, object_gallery_2, object_gallery_3],
        outputs=[image_output_sam, image_output_dounseen]
    )
    # Rebuild the module-level SAM2 mask generators whenever the grid density
    # setting changes; the next process() call picks up the new generators.
    number_input.change(remake_sam2_mask_generators, inputs=[number_input])

demo.launch(debug=False, show_error=True)