|
|
import logging

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from qai_hub_models.models.detr_resnet50 import Model
|
|
|
|
|
|
|
|
def _register_avif_support():
    """Register an AVIF decoder with Pillow if an optional plugin is installed.

    Tries ``pillow-avif-plugin`` first, then ``pillow-heif``. When pillow-heif
    is present its HEIF/HEIC opener is registered as well. If no AVIF opener
    could be registered, an install hint is printed.

    Returns:
        ``True`` if an AVIF opener was registered, ``False`` otherwise.
    """
    try:
        # pillow-avif-plugin documents `import pillow_avif` as the way to
        # register its opener (registration is an import side effect).
        import pillow_avif
    except ImportError:
        pass
    else:
        # Also honour an explicit registration function if a version exports one.
        register = getattr(pillow_avif, "register_avif_opener", None)
        if register is not None:
            register()
        return True

    try:
        import pillow_heif
    except ImportError:
        pillow_heif = None

    if pillow_heif is not None:
        # Keep HEIF/HEIC support, as before.
        pillow_heif.register_heif_opener()
        # register_heif_opener() only covers HEIF; pillow-heif exposes AVIF through
        # a separate register_avif_opener() in versions that still bundle it.
        register_avif = getattr(pillow_heif, "register_avif_opener", None)
        if register_avif is not None:
            register_avif()
            return True

    print("AVIF support not available. Please install 'pillow-avif-plugin' or 'pillow-heif'.")
    return False


_register_avif_support()
|
|
|
|
|
|
|
|
# Build the pretrained model a single time at import; detect_objects() looks
# up this module-level name on every request instead of reloading weights.
torch_model = Model.from_pretrained()
|
|
|
|
|
def _extract_predictions(outputs):
    """Pick the prediction tensor out of the raw model output.

    Args:
        outputs: Whatever ``torch_model`` returned -- a dict-like object, a
            bare tensor, or an indexable container such as a tuple.

    Returns:
        ``outputs["logits"]`` for a dict holding that key, ``outputs`` itself
        for a bare tensor, and ``outputs[0]`` otherwise.
    """
    if isinstance(outputs, torch.Tensor):
        # Tensor membership tests only accept tensors/scalars, so a string-key
        # check like `"logits" in outputs` would raise; return it directly.
        return outputs
    if isinstance(outputs, dict) and "logits" in outputs:
        return outputs["logits"]
    return outputs[0]


def detect_objects(image, confidence_threshold=0.8, input_size=800):
    """Run the detector on an uploaded image and draw the kept detections.

    Args:
        image: Array delivered by ``gr.Image(type="numpy")``; ``None`` when
            nothing has been uploaded.
        confidence_threshold: A prediction is kept only if its score is
            strictly greater than this value.
        input_size: Side length, in pixels, of the square image fed to the
            model. Box coordinates are rescaled from this size back to the
            original image size.

    Returns:
        ``(annotated_image, detections)``: an RGB PIL image at the original
        resolution with kept boxes drawn in red, and a list of dicts with
        ``"label"``, ``"confidence"`` (rounded to 3 decimals) and ``"box"``
        (rescaled to original-image coordinates).

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("No image uploaded!")

    original_image = Image.fromarray(image).convert("RGB")
    # resize() returns a new image, so drawing on original_image later does
    # not affect the model input.
    model_input = original_image.resize((input_size, input_size))

    # HWC uint8 pixels -> CHW float in [0, 1], plus a leading batch axis.
    image_tensor = torch.from_numpy(np.array(model_input)).permute(2, 0, 1)
    image_tensor = (image_tensor.float() / 255.0).unsqueeze(0)

    with torch.no_grad():
        outputs = torch_model(image_tensor)

    predictions = _extract_predictions(outputs)

    # NOTE(review): the decoding below treats each row of predictions[0] as box
    # values followed by a score in the last slot, with box values in pixels of
    # the input_size x input_size model input. DETR-style heads usually emit
    # class logits and normalized boxes separately -- confirm this layout
    # against the qai_hub_models DETR output before trusting the results.
    scale_x = original_image.width / input_size
    scale_y = original_image.height / input_size
    draw = ImageDraw.Draw(original_image)  # one drawing context for all boxes
    detections = []

    for i, row in enumerate(predictions[0].tolist()):
        score = row[-1]
        # Written as `not (>)` so NaN scores are skipped as well.
        if not (score > confidence_threshold):
            continue

        box = row[:-1]
        box[0] *= scale_x
        box[1] *= scale_y
        box[2] *= scale_x
        box[3] *= scale_y

        detection = {
            "label": f"Object {i}",
            "confidence": round(score, 3),
            "box": box,
        }
        detections.append(detection)

        # Pillow's rectangle() requires x1 >= x0 and y1 >= y0; only draw boxes
        # whose corners are ordered on both axes instead of crashing.
        if box[0] < box[2] and box[1] < box[3]:
            draw.rectangle(box[:4], outline="red", width=3)
            draw.text(
                (box[0], box[1]),
                f"{detection['label']} ({detection['confidence']})",
                fill="red",
            )

    return original_image, detections
|
|
|
|
|
|
|
|
with gr.Blocks() as iface:
    gr.Markdown("# Object Detection with DETR-ResNet50")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="numpy", label="Upload Image (supports PNG, JPEG, AVIF...)")
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear")

        with gr.Column(scale=1):
            output_image = gr.Image(label="Detected Image")
            output_json = gr.JSON(label="Detections")

    def on_submit(image):
        """Handle the Submit button by running detect_objects on the upload.

        Returns ``(annotated_image, detections)`` on success. On any failure the
        image panel is cleared and ``{"error": <message>}`` is shown in the JSON
        panel; the full traceback is written to the log.
        """
        try:
            detected_image, detections = detect_objects(image)
        except Exception as e:
            # UI boundary: keep the request alive for the user, but do not lose
            # the traceback needed to debug unexpected failures.
            logging.getLogger(__name__).exception("Object detection failed")
            return None, {"error": str(e)}
        return detected_image, detections

    def on_clear():
        """Reset the input image, the output image and the JSON panel."""
        return None, None, None

    submit_button.click(on_submit, inputs=image_input, outputs=[output_image, output_json])
    clear_button.click(on_clear, inputs=None, outputs=[image_input, output_image, output_json])
|
|
|
|
|
|
|
|
def main():
    """Launch the Gradio server for the object-detection interface."""
    iface.launch()


if __name__ == "__main__":
    main()
|
|
|