# Hugging Face Spaces: Running on Zero
import functools
import sys

# Mock audio modules to avoid installing them
sys.modules["audioop"] = type("audioop", (), {"__file__": ""})()
sys.modules["pyaudioop"] = type("pyaudioop", (), {"__file__": ""})()

import torch
import gradio as gr
import supervision as sv
import spaces
from PIL import Image
from transformers import AutoProcessor, Owlv2ForObjectDetection, Owlv2Processor
from transformers.models.owlv2.modeling_owlv2 import Owlv2ImageGuidedObjectDetectionOutput, center_to_corners_format, box_iou
#from transformers.models.owlv2.image_processing_owlv2
# Run on the GPU whenever PyTorch reports one is available; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
@functools.lru_cache(maxsize=2)
def init_model(model_id):
    """Load the OWLv2 processor and detection model for ``model_id``.

    The result is memoised per ``model_id`` so that repeated "Detect Objects"
    clicks reuse the weights already moved to ``DEVICE`` instead of loading
    them again on every request. The cache holds at most two entries, one per
    model offered in the UI dropdown. A failed load is not cached.

    Args:
        model_id: Hugging Face model id, e.g. ``"google/owlv2-base-patch16-ensemble"``.

    Returns:
        ``(processor, model, image_size, image_mean, image_std)`` where
        ``image_size`` is a tuple of the processor's ``size`` values and
        ``image_mean`` / ``image_std`` are tensors on ``DEVICE`` viewed as
        ``(1, 3, 1, 1)``.
    """
    processor = AutoProcessor.from_pretrained(model_id)
    model = Owlv2ForObjectDetection.from_pretrained(model_id)
    model.eval()
    model.to(DEVICE)
    image_size = tuple(processor.image_processor.size.values())
    # Reshaped so they broadcast over a (batch, channel, height, width) tensor.
    image_mean = torch.tensor(
        processor.image_processor.image_mean, device=DEVICE
    ).view(1, 3, 1, 1)
    image_std = torch.tensor(
        processor.image_processor.image_std, device=DEVICE
    ).view(1, 3, 1, 1)
    return processor, model, image_size, image_mean, image_std
def _flatten_patch_grid(feature_map):
    """Reshape a ``(batch, patches_h, patches_w, dim)`` grid to ``(batch, patches_h * patches_w, dim)``."""
    batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape
    return torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim))


def _annotate_image_guided(processor, model, image_feats, query_embeds, target_pred_boxes,
                           target_sizes, canvas, conf_thresh, iou_thresh):
    """Score target boxes against one query embedding, post-process them and draw them on ``canvas``."""
    pred_logits, _ = model.class_predictor(image_feats=image_feats, query_embeds=query_embeds)
    outputs = Owlv2ImageGuidedObjectDetectionOutput(
        logits=pred_logits,
        target_pred_boxes=target_pred_boxes,
    )
    result = processor.post_process_image_guided_detection(
        outputs=outputs,
        target_sizes=target_sizes,
        threshold=conf_thresh,
        nms_threshold=iou_thresh
    )[0]
    # supervision needs a label per box; every box answers the single prompt,
    # so all of them get class 0.
    result['labels'] = torch.zeros(len(result['boxes']), dtype=torch.int64)
    return annotate_image(result, {0: "object"}, canvas)


def _detect_with_text(processor, model, texts, target_image, conf_thresh):
    """Text-prompted detection; returns ``target_image`` annotated with one class per text."""
    inputs = processor(
        images=target_image,
        text=texts,
        return_tensors="pt"
    ).to(DEVICE)
    outputs = model(**inputs)
    target_sizes = torch.tensor([target_image.size[::-1]])
    result = processor.post_process_grounded_object_detection(
        outputs=outputs,
        target_sizes=target_sizes,
        threshold=conf_thresh
    )[0]
    return annotate_image(result, dict(enumerate(texts)), target_image)


def _detect_with_visual_prompt(processor, model, prompt_image, target_image, conf_thresh, iou_thresh):
    """Image-guided detection comparing two ways of choosing the prompt's query embedding.

    Returns:
        ``(annotated_image_my, annotated_image_hf, annotated_prompt_image)``.
    """
    inputs = processor(
        images=target_image,
        query_images=prompt_image,
        return_tensors="pt"
    ).to(DEVICE)
    query_feature_map = model.image_embedder(pixel_values=inputs.query_pixel_values)[0]
    feature_map = model.image_embedder(pixel_values=inputs.pixel_values)[0]
    image_feats = _flatten_patch_grid(feature_map)
    query_image_feats = _flatten_patch_grid(query_feature_map)

    # Select using hf method (transformers' built-in query selection).
    query_embeds_hf, box_indices, pred_boxes = model.embed_image_query(
        query_image_features=query_image_feats,
        query_feature_map=query_feature_map
    )

    # Select top object from prompt image by objectness * IoU with the real
    # (un-padded) prompt area.
    objectnesses = torch.sigmoid(model.objectness_predictor(query_image_feats))
    _, source_class_embeddings = model.class_predictor(query_image_feats)
    pw, ph = prompt_image.size
    max_side = max(pw, ph)
    # Normalised [x0, y0, x1, y1] box covering the prompt content; assumes the
    # processor pads on the right/bottom, as pad_to_square does.
    prompt_area_box = torch.tensor([[0, 0, pw / max_side, ph / max_side]], device=DEVICE)
    ious, _ = box_iou(prompt_area_box, center_to_corners_format(pred_boxes)[0])
    top_obj_idx = torch.argmax(objectnesses * ious, dim=-1)
    query_embeds_my = source_class_embeddings[0][top_obj_idx]

    # Box predictions on the target are shared by both selection methods.
    target_pred_boxes = model.box_predictor(image_feats, feature_map)
    target_sizes = torch.tensor([target_image.size[::-1]])
    # annotate_image draws on a copy, so one padded canvas serves both results.
    padded_target = pad_to_square(target_image)
    annotated_image_my = _annotate_image_guided(
        processor, model, image_feats, query_embeds_my, target_pred_boxes,
        target_sizes, padded_target, conf_thresh, iou_thresh,
    )
    annotated_image_hf = _annotate_image_guided(
        processor, model, image_feats, query_embeds_hf, target_pred_boxes,
        target_sizes, padded_target, conf_thresh, iou_thresh,
    )

    # Render the two selected prompt boxes (label 0 = "my", 1 = "hf").
    # NOTE(review): objectnesses are already sigmoid-ed and the post-processor
    # applies its own sigmoid/rescaling, so the drawn score is not the raw
    # objectness - confirm this is the intended display.
    query_pred_boxes = pred_boxes[0, [top_obj_idx, box_indices[0]]].unsqueeze(0)
    query_logits = torch.reshape(objectnesses[0, [top_obj_idx, box_indices[0]]], (1, 2, 1))
    query_outputs = Owlv2ImageGuidedObjectDetectionOutput(
        logits=query_logits,
        target_pred_boxes=query_pred_boxes,
    )
    query_result = processor.post_process_image_guided_detection(
        outputs=query_outputs,
        target_sizes=torch.tensor([prompt_image.size[::-1]]),
        threshold=0.0,
        nms_threshold=1.0
    )[0]
    query_result['labels'] = torch.tensor([0, 1], dtype=torch.int64)
    query_class_names = {0: "my", 1: "hf"}
    annotated_prompt_image = annotate_image(query_result, query_class_names, pad_to_square(prompt_image))
    return annotated_image_my, annotated_image_hf, annotated_prompt_image


def inference(prompts, target_image, model_id, conf_thresh, iou_thresh, prompt_type):
    """Run OWLv2 detection on ``target_image`` with a text or a visual prompt.

    Args:
        prompts: ``{"texts": [str, ...]}`` for text prompting or
            ``{"images": PIL.Image}`` for visual prompting.
        target_image: PIL image to search in.
        model_id: model id forwarded to ``init_model``.
        conf_thresh: score threshold forwarded to the post-processors.
        iou_thresh: NMS threshold, used for visual prompting only.
        prompt_type: ``"Text"`` or ``"Visual"``.

    Returns:
        ``(annotated_image_my, annotated_image_hf, annotated_prompt_image)``.
        Text prompting fills only ``annotated_image_hf``. An unknown
        ``prompt_type`` returns three ``None`` values.

    Raises:
        gr.Error: if the target image, the prompt image or all texts are missing.
    """
    if target_image is None:
        raise gr.Error("Please provide a target image.")
    processor, model, _, _, _ = init_model(model_id)
    # Nothing below needs gradients, including the predictor-head calls made
    # outside model.forward in the visual branch.
    with torch.no_grad():
        if prompt_type == "Text":
            texts = [text.strip() for text in (prompts.get("texts") or []) if text and text.strip()]
            if not texts:
                raise gr.Error("Please enter at least one text prompt.")
            return None, _detect_with_text(processor, model, texts, target_image, conf_thresh), None
        if prompt_type == "Visual":
            prompt_image = prompts.get("images")
            if prompt_image is None:
                raise gr.Error("Please provide a prompt image.")
            return _detect_with_visual_prompt(
                processor, model, prompt_image, target_image, conf_thresh, iou_thresh
            )
    return None, None, None
def annotate_image(result, class_names, image):
    """Draw boxes and "<class> <score>" captions from a post-processed result.

    Args:
        result: post-processed detection dict passed to
            ``sv.Detections.from_transformers``.
        class_names: mapping from label id to display name.
        image: PIL image to draw on; a copy is annotated, not ``image`` itself.

    Returns:
        The annotated copy of ``image``.
    """
    detections = sv.Detections.from_transformers(result, class_names)
    size_wh = image.size
    line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=size_wh)
    font_scale = sv.calculate_optimal_text_scale(resolution_wh=size_wh)

    captions = []
    for class_name, confidence in zip(detections['class_name'], detections.confidence):
        captions.append(f"{class_name} {confidence:.2f}")

    box_drawer = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=line_thickness)
    label_drawer = sv.LabelAnnotator(
        color_lookup=sv.ColorLookup.INDEX,
        text_scale=font_scale,
        smart_position=True,
    )
    canvas = box_drawer.annotate(scene=image.copy(), detections=detections)
    return label_drawer.annotate(scene=canvas, detections=detections, labels=captions)
def pad_to_square(image, background_color=(128, 128, 128)):
    """Return a square copy of ``image`` padded on the right and bottom.

    A new canvas with the same mode as ``image`` and a side equal to the longer
    of its width and height is filled with ``background_color``. The image is
    pasted at the top-left corner.
    """
    side = max(image.size)
    canvas = Image.new(image.mode, (side, side), background_color)
    canvas.paste(image, (0, 0))
    return canvas
def app():
    """Build the OWLv2 detection UI and wire up its events.

    Returns:
        ``(target_image, prompt_image, model_id, conf_thresh, iou_thresh,
        image_examples_list)`` so the caller can pre-fill the inputs on load.
    """
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                target_image = gr.Image(type="pil", label="Target Image", visible=True, interactive=True)
                detect_button = gr.Button(value="Detect Objects")
                # Hidden field that tracks which prompt tab is active.
                prompt_type = gr.Textbox(value='Visual', visible=False)  # Default prompt type
                with gr.Tab("Visual") as visual_tab:
                    prompt_image = gr.Image(type="pil", label="Prompt Image", visible=True, interactive=True)
                with gr.Tab("Text") as text_tab:
                    texts = gr.Textbox(label="Input Texts", value='', placeholder='person,bus', visible=True, interactive=True)
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "google/owlv2-base-patch16-ensemble",
                        "google/owlv2-large-patch14-ensemble"
                    ],
                    value="google/owlv2-base-patch16-ensemble",
                )
                conf_thresh = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25,
                )
                # IoU threshold for non-maximum suppression (visual prompts only).
                iou_thresh = gr.Slider(
                    label="NMS Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.70,
                )
            with gr.Column():
                output_image_hf_gr = gr.Group()
                with output_image_hf_gr:
                    gr.Markdown("### Annotated Image (HF default)")
                    output_image_hf = gr.Image(type="numpy", visible=True, show_label=False)
                output_image_my_gr = gr.Group()
                with output_image_my_gr:
                    gr.Markdown("### Annotated Image (Objectness × IoU variant)")
                    output_image_my = gr.Image(type="numpy", visible=True, show_label=False)
                annotated_prompt_image_gr = gr.Group()
                with annotated_prompt_image_gr:
                    gr.Markdown("### Prompt Image with Selected Embeddings and Objectness Score")
                    annotated_prompt_image = gr.Image(type="numpy", visible=True, show_label=False)

        # Switching tabs records the prompt type and shows/hides the
        # visual-only widgets. Leaving the Visual tab also clears the prompt image.
        visual_tab.select(
            fn=lambda: ("Visual", gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)),
            inputs=None,
            outputs=[prompt_type, prompt_image, output_image_my_gr, annotated_prompt_image_gr]
        )
        text_tab.select(
            fn=lambda: ("Text", gr.update(value=None, visible=False), gr.update(visible=False), gr.update(visible=False)),
            inputs=None,
            outputs=[prompt_type, prompt_image, output_image_my_gr, annotated_prompt_image_gr]
        )

        def run_inference(prompt_image, target_image, texts, model_id, conf_thresh, iou_thresh, prompt_type):
            """Convert raw UI values into a ``prompts`` dict and call ``inference``."""
            if target_image is None:
                raise gr.Error("Please provide a target image.")
            if prompt_type == "Text":
                # Split the comma-separated list, dropping blank entries such as
                # the one a trailing comma ("person,") would produce.
                parsed = [text.strip() for text in (texts or "").split(',') if text.strip()]
                if not parsed:
                    raise gr.Error("Please enter at least one text prompt.")
                prompts = {"texts": parsed}
            elif prompt_type == "Visual":
                if prompt_image is None:
                    raise gr.Error("Please provide a prompt image.")
                prompts = {"images": prompt_image}
            else:
                prompts = {}
            return inference(prompts, target_image, model_id, conf_thresh, iou_thresh, prompt_type)

        detect_button.click(
            fn=run_inference,
            inputs=[prompt_image, target_image, texts, model_id, conf_thresh, iou_thresh, prompt_type],
            outputs=[output_image_my, output_image_hf, annotated_prompt_image],
        )

        ###################### Examples ##########################
        # Each row: target image, prompt image, model id, conf threshold, NMS threshold.
        image_examples_list = [
            [
                "test-data/target1.jpg",
                "test-data/prompt1.jpg",
                "google/owlv2-base-patch16-ensemble",
                0.9,
                0.3,
            ],
            [
                "test-data/target2.jpg",
                "test-data/prompt2.jpg",
                "google/owlv2-base-patch16-ensemble",
                0.9,
                0.3,
            ],
            [
                "test-data/target3.jpg",
                "test-data/prompt3.jpg",
                "google/owlv2-base-patch16-ensemble",
                0.9,
                0.3,
            ],
            [
                "test-data/target4.jpg",
                "test-data/prompt4.jpg",
                "google/owlv2-base-patch16-ensemble",
                0.9,
                0.3,
            ],
            [
                "test-data/target5.jpg",
                "test-data/prompt5.jpg",
                "google/owlv2-base-patch16-ensemble",
                0.9,
                0.3,
            ],
            [
                "test-data/target6.jpg",
                "test-data/prompt6.jpg",
                "google/owlv2-base-patch16-ensemble",
                0.9,
                0.3,
            ],
        ]
        text_examples = gr.Examples(
            examples=[
                [
                    "test-data/target1.jpg",
                    "logo",
                    "google/owlv2-base-patch16-ensemble",
                    0.3,
                ],
                [
                    "test-data/target2.jpg",
                    "cat,remote",
                    "google/owlv2-base-patch16-ensemble",
                    0.3,
                ],
                [
                    "test-data/target3.jpg",
                    "frog,spider,lizard",
                    "google/owlv2-base-patch16-ensemble",
                    0.3,
                ],
                [
                    "test-data/target4.jpg",
                    "cat",
                    "google/owlv2-base-patch16-ensemble",
                    0.3,
                ],
                [
                    "test-data/target5.jpg",
                    "lemon,straw",
                    "google/owlv2-base-patch16-ensemble",
                    0.3,
                ],
                [
                    "test-data/target6.jpg",
                    "beer logo",
                    "google/owlv2-base-patch16-ensemble",
                    0.3,
                ],
            ],
            inputs=[target_image, texts, model_id, conf_thresh],
            visible=False, cache_examples=False, label="Text Prompt Examples")
        image_examples = gr.Examples(
            examples=image_examples_list,
            inputs=[target_image, prompt_image, model_id, conf_thresh, iou_thresh],
            visible=True, cache_examples=False, label="Box Visual Prompt Examples")

        # Show only the examples (and the NMS slider) relevant to the active tab.
        def update_text_examples():
            return gr.Dataset(visible=True), gr.Dataset(visible=False), gr.update(visible=False)

        def update_visual_examples():
            return gr.Dataset(visible=False), gr.Dataset(visible=True), gr.update(visible=True)

        text_tab.select(
            fn=update_text_examples,
            inputs=None,
            outputs=[text_examples.dataset, image_examples.dataset, iou_thresh]
        )
        visual_tab.select(
            fn=update_visual_examples,
            inputs=None,
            outputs=[text_examples.dataset, image_examples.dataset, iou_thresh]
        )
    return target_image, prompt_image, model_id, conf_thresh, iou_thresh, image_examples_list
# Top-level page: header, description, the detection UI from app(), and an
# on-load handler that pre-fills the inputs with a visual-prompt example.
gradio_app = gr.Blocks()
with gradio_app:
    gr.HTML(
        """
        <h1 style='text-align: center'>OWLv2: Zero-shot detection with visual prompt</h1>
        """)
    gr.Markdown("""
This demo showcases the OWLv2 model's ability to perform zero-shot object detection using visual and text prompts.
You can either provide a text prompt or an image as a visual prompt to detect objects in the target image.
Additionally, it compares different approaches for selecting a query embedding from a visual prompt. The method used in Hugging Face's `transformers` by default often underperforms because of how the visual prompt embedding is selected (see README.md for more details).
""")
    with gr.Row():
        with gr.Column():
            # Build the UI; app() hands back the inputs that are pre-filled on load.
            target_image, prompt_image, model_id, conf_thresh, iou_thresh, image_examples_list = app()
    # Pre-fill the inputs with the second visual-prompt example when the page loads.
    gradio_app.load(
        fn=lambda: image_examples_list[1],
        outputs=[target_image, prompt_image, model_id, conf_thresh, iou_thresh]
    )

if __name__ == "__main__":
    gradio_app.launch(allowed_paths=["figures"])