| | import torch |
| | import cv2 |
| | import gradio as gr |
| | import numpy as np |
| | import requests |
| | from PIL import Image |
| | from io import BytesIO |
| | from transformers import OwlViTProcessor, OwlViTForObjectDetection |
| |
|
| |
|
| | |
# Run inference on the GPU when CUDA is available; otherwise use the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| |
|
# Load the input processor and the detector from the same pretrained
# checkpoint, then move the detector to the selected device and switch it to
# evaluation mode.
processor = OwlViTProcessor.from_pretrained("google/owlvit-large-patch14")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-large-patch14")
model = model.to(device)
model.eval()
| |
|
| |
|
def query_image(img_url, text_queries, score_threshold):
    """Download an image and draw OWL-ViT detections for the given text queries.

    Args:
        img_url: URL of the image, fetched with an HTTP GET request.
        text_queries: Comma-separated query phrases, e.g. ``"cat, remote"``.
            Whitespace around each phrase is stripped and empty phrases are
            ignored.
        score_threshold: Minimum score a detection needs in order to be drawn.

    Returns:
        The downloaded image as an RGB ``numpy`` array, with a rectangle and
        the matching query text drawn for every detection whose score is at
        least ``score_threshold``. If no non-empty query is supplied, the
        image is returned without annotations.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server responds with an HTTP error status.
    """
    queries = [query.strip() for query in text_queries.split(",")]
    queries = [query for query in queries if query]

    # Bound the request so an unresponsive host cannot hang the handler, and
    # fail on HTTP errors rather than handing an error page to the decoder.
    response = requests.get(img_url, timeout=30)
    response.raise_for_status()

    # Force three channels: palette, grayscale and RGBA images would otherwise
    # produce arrays that do not match the RGB colour tuples used for drawing.
    with Image.open(BytesIO(response.content)) as pil_img:
        img = np.array(pil_img.convert("RGB"))

    if not queries:
        return img

    img_height = img.shape[0]
    # (height, width) of the image, used to rescale the predicted boxes.
    target_sizes = torch.Tensor([img.shape[:2]])
    inputs = processor(text=queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Bring predictions back to the CPU, where target_sizes was created.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    font = cv2.FONT_HERSHEY_SIMPLEX
    color = (255, 0, 0)

    for box, score, label in zip(boxes, scores, labels):
        if score >= score_threshold:
            x_min, y_min, x_max, y_max = (int(coord) for coord in box.tolist())
            img = cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color, 5)

            # Write the label just below the box, unless that would fall past
            # the bottom edge of this image; then write it inside the box.
            if y_max + 25 > img_height:
                text_y = y_max - 10
            else:
                text_y = y_max + 25

            img = cv2.putText(
                img, queries[int(label)], (x_min, text_y), font, 1, color, 2, cv2.LINE_AA
            )
    return img
| |
|
| |
|
# Text shown under the title of the demo page.
description = """
DEMO
"""

# Build the web UI: two text boxes and a 0-1 slider feed query_image's three
# parameters in order, and its returned array is shown as an image.
demo = gr.Interface(
    fn=query_image,
    title="Zero-Shot Object Detection with OWL-ViT",
    description=description,
    inputs=["text", "text", gr.Slider(minimum=0, maximum=1, value=0.1)],
    outputs="image",
    examples=[],
)
demo.launch()