# Author: gannushalini2006
# Commit 7f51efa: "Update app.py" (verified)
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from ultralytics import YOLO
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from transformers import AutoImageProcessor, AutoModelForObjectDetection
# -------------------------------------------------
# Device
# -------------------------------------------------
# Torch-based models (Faster R-CNN, Deformable DETR) run on the GPU when one
# is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -------------------------------------------------
# Load Models
# -------------------------------------------------
# YOLOv8 nano checkpoint. NOTE(review): it is not moved to `device` here;
# device placement is left to the ultralytics wrapper.
yolo = YOLO("yolov8n.pt")
# Faster R-CNN (ResNet-50 FPN) with its default pretrained weights.
# `weights="DEFAULT"` is the torchvision >= 0.13 replacement for the
# deprecated `pretrained=True` flag and loads the same COCO checkpoint.
frcnn = fasterrcnn_resnet50_fpn(weights="DEFAULT")
frcnn.to(device).eval()
# Deformable DETR: the image processor handles resizing/normalisation and
# post-processing; the model produces the raw detections.
processor = AutoImageProcessor.from_pretrained(
    "SenseTime/deformable-detr",
    use_fast=False,
)
detr = AutoModelForObjectDetection.from_pretrained(
    "SenseTime/deformable-detr"
)
detr.to(device).eval()
# -------------------------------------------------
# Utility Functions
# -------------------------------------------------
def compute_iou(box1, box2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    A ``1e-6`` term is added to the denominator so that degenerate
    (zero-area) boxes do not cause a division by zero.
    """
    ax1, ay1, ax2, ay2 = box1[0], box1[1], box1[2], box1[3]
    bx1, by1, bx2, by2 = box2[0], box2[1], box2[2], box2[3]
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = overlap_w * overlap_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - intersection
    return intersection / (union + 1e-6)
def draw_boxes(image, detections):
    """Draw green, labelled bounding boxes onto a copy of ``image``.

    Args:
        image: PIL image (anything ``np.array`` can turn into an HxWxC array).
        detections: iterable of dicts with a ``"box"`` entry
            ``(x1, y1, x2, y2)`` in pixel coordinates and an optional
            ``"label"``; when the label is absent the ``"model"`` name (or an
            empty string) is drawn instead.

    Returns:
        A new PIL image with the boxes and labels drawn on it.
    """
    # np.array copies the pixels, so the caller's image is left untouched.
    img = np.array(image)
    for d in detections:
        x1, y1, x2, y2 = map(int, d["box"])
        label = str(d.get("label", d.get("model", "")))
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # The text baseline sits 6px above the box; for boxes touching the
        # top edge that would put the label outside the frame, so clamp it
        # to a baseline where a 0.5-scale glyph is still visible.
        text_y = max(y1 - 6, 12)
        cv2.putText(
            img, label, (x1, text_y),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1,
        )
    return Image.fromarray(img)
# -------------------------------------------------
# Model Inference
# -------------------------------------------------
def yolo_detect(image):
    """Run YOLOv8 on ``image`` and tag each predicted box with the model name.

    Returns a list of ``{"box": ndarray(x1, y1, x2, y2), "model": "YOLO"}``
    dicts. No score filtering is done here beyond whatever the ``yolo`` call
    applies itself.
    """
    prediction = yolo(image)[0]
    return [
        {"box": box.xyxy[0].cpu().numpy(), "model": "YOLO"}
        for box in prediction.boxes
    ]
def frcnn_detect(image, score_thresh=0.6):
    """Run Faster R-CNN on ``image`` and keep boxes scoring above a threshold.

    Args:
        image: PIL image. It is converted to RGB so grayscale/RGBA inputs
            still yield the 3-channel tensor the model is fed.
        score_thresh: detections must score strictly above this (default 0.6).

    Returns:
        list of ``{"box": ndarray(x1, y1, x2, y2), "model": "FRCNN"}`` dicts.
    """
    # HWC uint8 -> float32 in [0, 1] directly, avoiding the float64
    # intermediate produced by dividing a uint8 array by a Python float.
    pixels = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
    # CHW layout with a leading batch dimension of 1.
    img = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        out = frcnn(img)[0]
    keep = out["scores"] > score_thresh
    return [
        {"box": box, "model": "FRCNN"}
        for box in out["boxes"][keep].cpu().numpy()
    ]
def detr_detect(image, threshold=0.7):
    """Run Deformable DETR on ``image`` and return boxes above ``threshold``.

    Args:
        image: PIL image; non-RGB images are converted to RGB first.
        threshold: score threshold forwarded to the processor's
            ``post_process_object_detection`` (default 0.7).

    Returns:
        list of ``{"box": ndarray(x1, y1, x2, y2), "model": "DETR"}`` dicts,
        with boxes rescaled to the input image's pixel size.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = detr(**inputs)
    # PIL reports (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]], device=device)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]
    return [
        {"box": box, "model": "DETR"}
        for box in results["boxes"].cpu().numpy()
    ]
# -------------------------------------------------
# HARD VOTING
# -------------------------------------------------
def hard_vote(detections, vote_thresh=2, iou_thresh=0.5, dedup_iou=0.8):
    """Fuse boxes from several detectors by IoU-based hard voting.

    For every detection ``d`` a vote set is built from ``d`` itself plus, for
    each *other* model, that model's single box with the highest IoU against
    ``d``, if that IoU is >= ``iou_thresh``. Taking at most one box per model
    prevents a detector that emits several overlapping boxes from casting
    multiple votes and skewing the averaged box. When the vote set spans at
    least ``vote_thresh`` distinct models, the coordinate-wise mean of its
    boxes becomes an ensemble detection.

    Each member of an agreeing group produces its own, usually near-identical,
    fused box. Fused boxes whose IoU with an already kept one exceeds
    ``dedup_iou`` are therefore dropped (first one wins).

    Args:
        detections: list of dicts with ``"box"`` (x1, y1, x2, y2) and ``"model"``.
        vote_thresh: minimum number of distinct models that must agree.
        iou_thresh: minimum IoU for another model's box to count as a vote.
        dedup_iou: IoU above which two fused boxes count as duplicates.

    Returns:
        list of ``{"box": ndarray, "label": "Ensemble (<n_models>)"}`` dicts.
    """
    # Group boxes by model once, so each candidate scans other models directly.
    by_model = {}
    for det in detections:
        by_model.setdefault(det["model"], []).append(det)

    fused = []
    for d in detections:
        votes = [d]
        for model, others in by_model.items():
            if model == d["model"]:
                continue
            # `others` is never empty: it was created by appending a detection.
            best_iou, best = max(
                ((compute_iou(d["box"], o["box"]), o) for o in others),
                key=lambda pair: pair[0],
            )
            if best_iou >= iou_thresh:
                votes.append(best)
        n_models = len({v["model"] for v in votes})
        if n_models >= vote_thresh:
            fused.append({
                "box": np.mean([v["box"] for v in votes], axis=0),
                "label": f"Ensemble ({n_models})",
            })

    # Remove duplicate fused boxes produced by members of the same group.
    unique = []
    for f in fused:
        if not any(compute_iou(f["box"], u["box"]) > dedup_iou for u in unique):
            unique.append(f)
    return unique
# -------------------------------------------------
# LIVE FRAME FUNCTION
# -------------------------------------------------
def live_detect(frame):
    """Run the three-detector ensemble on one webcam frame and annotate it.

    Args:
        frame: image array from the Gradio webcam stream, or ``None``.
            NOTE(review): streaming inputs may deliver ``None`` (e.g. before
            the camera starts); it is passed through instead of crashing
            ``Image.fromarray``.

    Returns:
        The annotated frame as a numpy array, or ``None`` if ``frame`` is ``None``.
    """
    if frame is None:
        return None
    image = Image.fromarray(frame)
    detections = yolo_detect(image) + frcnn_detect(image) + detr_detect(image)
    voted = hard_vote(detections)
    return np.array(draw_boxes(image, voted))
# -------------------------------------------------
# Gradio Interface (Webcam)
# -------------------------------------------------
def _webcam_input():
    """Build a streaming webcam ``gr.Image`` for the installed Gradio version.

    Gradio 4 replaced the ``source="webcam"`` keyword with
    ``sources=["webcam"]`` and no longer accepts the old one, so select the
    keyword based on the running major version.
    """
    major = int(gr.__version__.split(".")[0])
    if major >= 4:
        return gr.Image(sources=["webcam"], streaming=True)
    return gr.Image(source="webcam", streaming=True)


demo = gr.Interface(
    fn=live_detect,
    inputs=_webcam_input(),
    outputs=gr.Image(),
    live=True,
    title="Live Object Detection – Hard Voting Ensemble",
    description=(
        "YOLOv8 + Faster R-CNN + Deformable DETR\n"
        "Browser-based webcam with IoU-based hard voting."
    ),
)
demo.launch()