Spaces:
Sleeping
Sleeping
File size: 3,876 Bytes
8a44470 90a85b5 1379c9a 8a44470 90a85b5 1379c9a 90a85b5 1379c9a 8a44470 1379c9a 8a44470 1379c9a 8a44470 1379c9a 90a85b5 1379c9a 90a85b5 1379c9a 90a85b5 1379c9a 90a85b5 1379c9a 90a85b5 1379c9a 8a44470 90a85b5 1379c9a 90a85b5 1379c9a 90a85b5 1379c9a 90a85b5 8a44470 1379c9a 90a85b5 1379c9a 90a85b5 1379c9a 8a44470 1379c9a 8a44470 90a85b5 1379c9a 90a85b5 8a44470 90a85b5 8a44470 1379c9a 8a44470 1379c9a 8a44470 1379c9a 8a44470 e1d2b3a 1379c9a af0a2d0 1379c9a 90a85b5 1379c9a 8a44470 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
# ======================================================
# Load YOLO model
# ======================================================
# Weights file for the detector; change this path to use your own model.
MODEL_WEIGHTS = "rix_reg.pt"
# Loaded once when the module is imported and reused by every detect_* handler.
model = YOLO(MODEL_WEIGHTS)
def get_model_names():
    """Return the class-id -> label table exposed by the loaded YOLO model.

    Looks at ``model.names`` first (skipping it when missing or ``None``),
    then at ``model.model.names``. Returns an empty dict when neither exists.
    """
    top_level = getattr(model, "names", None)
    if top_level is not None:
        return top_level

    inner = getattr(model, "model", None)
    if inner is not None and hasattr(inner, "names"):
        return inner.names

    return {}
# ======================================================
# Function to count all objects in the model
# ======================================================
def count_objects(results, names=None):
    """Tally detected objects per class label across a list of YOLO results.

    Args:
        results: Iterable of Ultralytics ``Results`` objects (as returned by
            ``model.predict``); each result's ``boxes.cls`` holds class ids.
        names: Optional class-id -> label mapping (dict or list). When omitted,
            the loaded model's table from ``get_model_names()`` is used.

    Returns:
        dict mapping each label to its count, plus a ``"Total"`` key holding
        the sum of all counts. A class id that has no entry in ``names`` is
        reported under its numeric id rendered as a string.
    """
    if names is None:
        names = get_model_names()

    counter = {}
    for r in results:
        boxes = getattr(r, "boxes", None)
        # Results without boxes (presumably non-detection tasks) have nothing to count.
        if boxes is None:
            continue
        for raw_id in boxes.cls:
            cls_id = int(raw_id)
            try:
                label = str(names[cls_id])
            except (KeyError, IndexError, TypeError):
                # Unknown id or empty name table: fall back to the raw id
                # instead of crashing the whole request.
                label = str(cls_id)
            counter[label] = counter.get(label, 0) + 1

    counter["Total"] = sum(counter.values())
    return counter
# ======================================================
# Tab 1 - Image processing
# ======================================================
def detect_image(img):
    """Detect and count objects in an uploaded image.

    Args:
        img: RGB image array from ``gr.Image(type="numpy")``, or ``None`` when
            the button is pressed without an upload. Assumes 3 channels
            (gr.Image's default RGB image mode) — TODO confirm if image_mode changes.

    Returns:
        ``(annotated_rgb_image, counts)``, or ``(None, {"Error": ...})`` when
        no image was supplied.
    """
    if img is None:
        # Without this guard Ultralytics would fall back to its bundled demo
        # images when given a None source.
        return None, {"Error": "No image provided"}
    # Gradio supplies RGB, but Ultralytics interprets numpy input as BGR
    # (OpenCV order); convert so the model sees the true colours.
    frame_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    results = model.predict(frame_bgr, imgsz=640)
    # Results.plot() returns a BGR array; gr.Image displays RGB.
    annotated = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
    dashboard = count_objects(results)
    return annotated, dashboard
# ======================================================
# Tab 2 - Video processing
# ======================================================
def detect_video(video_path):
    """Detect and count objects in the first frame of an uploaded video.

    Only the first frame is analysed, as a quick preview matching the
    "Demo Frame Result" output; the rest of the clip is not processed.

    Args:
        video_path: File path of the uploaded video (as supplied by
            ``gr.Video``), or ``None`` when nothing was uploaded.

    Returns:
        ``(annotated_rgb_frame, counts)`` on success, or
        ``(None, {"Error": ...})`` when no video was given or its first frame
        cannot be read.
    """
    if video_path is None:
        return None, {"Error": "No video provided"}

    cap = cv2.VideoCapture(video_path)
    try:
        ret, frame = cap.read()
    finally:
        # Release on every path; previously the read-failure return leaked it.
        cap.release()

    if not ret:
        return None, {"Error": "Cannot read video"}

    # OpenCV decodes frames as BGR, which is what YOLO expects for numpy input.
    results = model.predict(frame, imgsz=640)
    # Results.plot() returns BGR; convert so gr.Image shows correct colours.
    annotated = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
    dashboard = count_objects(results)
    return annotated, dashboard
# ======================================================
# Tab 3 - Live camera
# ======================================================
def detect_camera(frame):
    """Detect and count objects in one webcam frame streamed from the browser.

    Args:
        frame: RGB frame array from the webcam ``gr.Image(type="numpy")``, or
            ``None`` if no frame was delivered.

    Returns:
        ``(annotated_rgb_frame, counts)``, or ``(None, {"Error": ...})`` when
        no frame was received.
    """
    if frame is None:
        return None, {"Error": "No camera frame"}
    # Gradio frames are RGB; Ultralytics treats numpy input as BGR.
    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    # Called once per streamed frame, so suppress per-frame inference logging.
    results = model.predict(frame_bgr, imgsz=640, verbose=False)
    # Results.plot() returns BGR; convert back for display in gr.Image.
    annotated = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
    dashboard = count_objects(results)
    return annotated, dashboard
# ======================================================
# GRADIO interface
# ======================================================
with gr.Blocks(title="Rix Detection") as demo:
    gr.Markdown("## 🛠️ Object Counting Dashboard")
    with gr.Tabs():
        # ==================== TAB 1 ====================
        with gr.Tab("Image Detection"):
            img_input = gr.Image(type="numpy", label="Upload Image")
            img_out = gr.Image(label="Result Image")
            dashboard1 = gr.JSON(label="Counts")
            btn1 = gr.Button("Detect")
            btn1.click(
                fn=detect_image,
                inputs=img_input,
                outputs=[img_out, dashboard1],
            )

        # ==================== TAB 2 ====================
        with gr.Tab("Video Detection"):
            video_input = gr.Video(label="Upload Video")
            # Only the first frame is analysed, hence an Image output.
            video_out = gr.Image(label="Demo Frame Result")
            dashboard2 = gr.JSON(label="Counts")
            btn2 = gr.Button("Detect Video")
            btn2.click(
                fn=detect_video,
                inputs=video_input,
                outputs=[video_out, dashboard2],
            )

        # ==================== TAB 3 ====================
        with gr.Tab("Live Camera"):
            # streaming=True makes the webcam emit frames continuously so the
            # .stream() listener below fires per frame; without it the
            # component behaves as a one-shot snapshot input.
            cam_input = gr.Image(
                sources=["webcam"],
                type="numpy",
                streaming=True,
                label="Camera",
            )
            cam_out = gr.Image(label="Real-time Result")
            dashboard3 = gr.JSON(label="Counts")
            cam_input.stream(
                fn=detect_camera,
                inputs=cam_input,
                outputs=[cam_out, dashboard3],
            )

if __name__ == "__main__":
    demo.launch()
|