# Rix_Detection / app.py
# NV9523's picture
# Update app.py
# 1379c9a verified
import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
# ======================================================
# Load YOLO model
# ======================================================
# Module-level detector shared by every handler below; it is constructed once,
# when this file is executed, from the weights file named in the call.
model = YOLO("rix_reg.pt") # change to your model
def get_model_names():
    """Return the class-name lookup exposed by the loaded model.

    The wrapper's own ``names`` attribute is preferred when it is present and
    not ``None``; otherwise the ``names`` attribute of the wrapped
    ``model.model`` object is used. An empty dict is returned when neither
    attribute can be read.
    """
    direct = getattr(model, "names", None)
    if direct is not None:
        return direct
    try:
        # Any missing attribute along the chain means "no names available".
        return model.model.names
    except AttributeError:
        return {}
# ======================================================
# Function to count all objects in the model
# ======================================================
def count_objects(results, names=None):
    """Count detections per class label across a batch of results.

    Args:
        results: Iterable of result objects. Each is expected to expose
            ``boxes.cls``, an iterable of numeric class ids. Entries whose
            ``boxes`` attribute is missing or ``None`` are skipped.
        names: Optional lookup from integer class id to label (a dict or a
            sequence). When omitted, ``get_model_names()`` is used.

    Returns:
        dict: ``label -> count`` in first-seen order, followed by a
        ``"Total"`` entry holding the number of detections counted.
    """
    if names is None:
        names = get_model_names()

    counts = {}
    total = 0
    for result in results:
        boxes = getattr(result, "boxes", None)
        if boxes is None:
            continue
        for raw_id in boxes.cls:
            class_id = int(raw_id)
            try:
                label = str(names[class_id])
            except (KeyError, IndexError):
                # Unknown id (e.g. empty name table): fall back to the id
                # itself instead of aborting the whole count.
                label = str(class_id)
            counts[label] = counts.get(label, 0) + 1
            total += 1

    counts["Total"] = total
    return counts
# ======================================================
# Tab 1 - Image processing
# ======================================================
def detect_image(img):
    """Run detection on one uploaded image and summarise the result.

    Args:
        img: Image array delivered by the Gradio image input
            (``type="numpy"``, RGB channel order per the Gradio docs), or
            ``None`` when nothing was uploaded.

    Returns:
        tuple: ``(annotated_image, counts)``. ``annotated_image`` is the
        plotted result converted back to RGB for display; ``counts`` is the
        dict produced by ``count_objects``. When ``img`` is ``None`` the
        tuple is ``(None, {"Error": "No image provided"})``.
    """
    if img is None:
        # The button can be clicked before an image is uploaded; do not hand
        # an empty source to the model.
        return None, {"Error": "No image provided"}

    # Ultralytics documents numpy sources as BGR and plot() output as BGR,
    # while Gradio works in RGB, so convert on the way in and on the way out.
    bgr_input = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    results = model.predict(bgr_input, imgsz=640)
    annotated = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
    dashboard = count_objects(results)
    return annotated, dashboard
# ======================================================
# Tab 2 - Video processing
# ======================================================
def detect_video(video_path):
    """Run detection on the first frame of a video file.

    Only the first frame is processed (demo behaviour).

    Args:
        video_path: Path of the uploaded video, or ``None``/empty when no
            file was provided.

    Returns:
        tuple: ``(annotated_frame, counts)`` with the frame converted to RGB
        for display, or ``(None, {"Error": "Cannot read video"})`` when no
        frame could be read.
    """
    if not video_path:
        return None, {"Error": "Cannot read video"}

    cap = cv2.VideoCapture(video_path)
    try:
        ret, frame = cap.read()
    finally:
        # Release the capture on every path, including the failure path.
        cap.release()

    if not ret:
        return None, {"Error": "Cannot read video"}

    # demo first frame: OpenCV frames are already BGR, which is the channel
    # order Ultralytics documents for numpy input.
    results = model.predict(frame, imgsz=640)
    # plot() output is BGR; the Gradio image output displays RGB.
    annotated = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
    dashboard = count_objects(results)
    return annotated, dashboard
# ======================================================
# Tab 3 - Live camera
# ======================================================
def detect_camera(frame):
    """Run detection on one streamed webcam frame.

    Args:
        frame: Frame array delivered by the Gradio webcam input
            (``type="numpy"``, RGB channel order per the Gradio docs), or
            ``None`` when no frame is available.

    Returns:
        tuple: ``(annotated_frame, counts)`` with the plotted result converted
        back to RGB, or ``(None, {"Error": "No camera frame"})`` when
        ``frame`` is ``None``.
    """
    if frame is None:
        # The stream may fire before the camera delivers a frame.
        return None, {"Error": "No camera frame"}

    # Ultralytics documents numpy sources as BGR and plot() output as BGR,
    # while Gradio works in RGB, so convert on the way in and on the way out.
    bgr_input = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    results = model.predict(bgr_input, imgsz=640)
    annotated = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
    dashboard = count_objects(results)
    return annotated, dashboard
# ======================================================
# GRADIO interface
# ======================================================
# Three-tab UI: each tab wires one input component to its handler and shows
# the annotated image next to a JSON panel with the per-class counts.
with gr.Blocks(title="Rix Detection") as demo:
    gr.Markdown("## 🛠️ Object Counting Dashboard")
    with gr.Tabs():
        # ==================== TAB 1 ====================
        # Single image: detection runs when the button is clicked.
        with gr.Tab("Image Detection"):
            img_input = gr.Image(type="numpy", label="Upload Image")
            img_out = gr.Image(label="Result Image")
            dashboard1 = gr.JSON(label="Counts")
            btn1 = gr.Button("Detect")
            btn1.click(
                fn=detect_image,
                inputs=img_input,
                outputs=[img_out, dashboard1],
            )
        # ==================== TAB 2 ====================
        # Video: the handler returns one annotated frame, hence an Image output.
        with gr.Tab("Video Detection"):
            video_input = gr.Video(label="Upload Video")
            video_out = gr.Image(label="Demo Frame Result")
            dashboard2 = gr.JSON(label="Counts")
            btn2 = gr.Button("Detect Video")
            btn2.click(
                fn=detect_video,
                inputs=video_input,
                outputs=[video_out, dashboard2],
            )
        # ==================== TAB 3 ====================
        # Webcam: the handler is attached to the component's stream event.
        with gr.Tab("Live Camera"):
            cam_input = gr.Image(sources=["webcam"], type="numpy", label="Camera")
            cam_out = gr.Image(label="Real-time Result")
            dashboard3 = gr.JSON(label="Counts")
            cam_input.stream(
                fn=detect_camera,
                inputs=cam_input,
                outputs=[cam_out, dashboard3],
            )

demo.launch()