Spaces:
Running
Running
| import torch | |
| import gradio as gr | |
| from transformers import Owlv2Processor, Owlv2ForObjectDetection | |
| import cv2 | |
| import spaces | |
| # =============================== | |
| # DEVICE | |
| # =============================== | |
# Pick the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model and processor are loaded from the same checkpoint, so name it once.
_checkpoint = "google/owlv2-base-patch16-ensemble"

# Detection model, moved onto the selected device.
model = Owlv2ForObjectDetection.from_pretrained(_checkpoint).to(device)

# Matching processor: builds model inputs and post-processes raw outputs.
processor = Owlv2Processor.from_pretrained(_checkpoint)
| # =============================== | |
| # MAIN FUNCTION | |
| # =============================== | |
def query_image(img, text_queries, score_threshold):
    """Detect objects matching free-text queries and outline them on the image.

    Args:
        img: Input image as a NumPy array (the UI passes ``type="numpy"``).
            ``img.shape[:2]`` is used as the (height, width) target size.
        text_queries: Comma-separated class names or phrases to look for.
        score_threshold: Minimum detection score a box needs to be kept.

    Returns:
        A tuple ``(annotated, output_boxes)`` where ``annotated`` is a copy of
        ``img`` with a green rectangle drawn per detection, and
        ``output_boxes`` is a list of ``[x1, y1, x2, y2]`` integer lists.

    Raises:
        gr.Error: If no image was supplied or no non-empty query was given.
    """
    # NOTE(review): `spaces` is imported at file level but no `@spaces.GPU`
    # decorator is applied here -- confirm whether this Space targets ZeroGPU.
    if img is None:
        raise gr.Error("Please upload an image.")

    # Strip whitespace and drop empty entries so "cat, dog," becomes
    # ["cat", "dog"] rather than ["cat", " dog", ""].
    queries = [q.strip() for q in (text_queries or "").split(",") if q.strip()]
    if not queries:
        raise gr.Error("Please enter at least one class name.")

    # Prepare inputs
    inputs = processor(
        text=queries,
        images=img,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Move outputs to CPU before post-processing
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()

    # (height, width) of the original image, one row per image in the batch
    target_sizes = torch.tensor([img.shape[:2]])

    # Forward the user's threshold explicitly: when it is omitted the
    # post-processor applies its own default cutoff (0.1 per the transformers
    # docs), which silently overrides any lower value chosen on the slider.
    results = processor.post_process_grounded_object_detection(
        outputs=outputs,
        target_sizes=target_sizes,
        threshold=score_threshold,
    )[0]

    # Draw on a copy so the caller's array is not modified in place.
    annotated = img.copy()
    output_boxes = []

    # Thresholding already happened in post-processing; every box left is kept.
    for box in results["boxes"]:
        x1, y1, x2, y2 = map(int, box.tolist())

        # Save ONLY coordinates
        output_boxes.append([x1, y1, x2, y2])

        # Draw rectangle ONLY (no labels)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)

    return annotated, output_boxes
| # =============================== | |
| # GRADIO UI | |
| # =============================== | |
# Input components, in the positional order query_image receives them:
# image, comma-separated query text, score threshold.
_image_input = gr.Image(type="numpy")
_classes_input = gr.Textbox(label="Classes (comma separated)")
_threshold_input = gr.Slider(0, 1, value=0.1)

# Output components, matching the (image, boxes) pair query_image returns.
_image_output = gr.Image(label="Bounding Boxes")
_boxes_output = gr.JSON(label="Coordinates Only")

demo = gr.Interface(
    fn=query_image,
    inputs=[_image_input, _classes_input, _threshold_input],
    outputs=[_image_output, _boxes_output],
    title="OWLv2 Bounding Box Coordinates Only",
)

# Start the web app.
demo.launch()