# Hugging Face Space: Rix Detection (status at scrape time: Sleeping)
from collections import Counter

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO
# ======================================================
# Load YOLO model
# ======================================================
# Module-level load: the weights are read once at import time and shared
# by all three detection tabs below. "rix_reg.pt" is a custom-trained
# weights file — swap in your own path to use a different model.
model = YOLO("rix_reg.pt")  # change to your model
def get_model_names():
    """Return the model's class-id -> name mapping, or {} if unavailable.

    Tries ``model.names`` first, then falls back to the wrapped
    ``model.model.names`` attribute some Ultralytics versions expose.
    """
    names = getattr(model, "names", None)
    if names is not None:
        return names
    inner = getattr(model, "model", None)
    if inner is not None and hasattr(inner, "names"):
        return inner.names
    return {}
| # ====================================================== | |
| # Function to count all objects in the model | |
| # ====================================================== | |
# ======================================================
# Function to count all objects in the model
# ======================================================
def count_objects(results):
    """Tally detections per class label across all YOLO results.

    Args:
        results: iterable of Ultralytics result objects; each exposes
            ``r.boxes.cls`` with the detected class ids.

    Returns:
        dict mapping class label -> detection count, plus a ``"Total"``
        key holding the overall number of detections.
    """
    names = get_model_names()
    counter = Counter()
    for r in results:
        for cls_id in r.boxes.cls:
            # cls ids come back as tensors/floats; normalize to int keys.
            counter[str(names[int(cls_id)])] += 1
    dashboard = dict(counter)
    # Grand total of all detections (computed before the key is added,
    # matching the original evaluation order).
    dashboard["Total"] = sum(counter.values())
    return dashboard
| # ====================================================== | |
| # Tab 1 - Image processing | |
| # ====================================================== | |
# ======================================================
# Tab 1 - Image processing
# ======================================================
def detect_image(img):
    """Run detection on a single image.

    Returns a tuple ``(annotated_image, counts_dict)`` for the Gradio
    image output and JSON dashboard.
    """
    preds = model.predict(img, imgsz=640)
    return preds[0].plot(), count_objects(preds)
| # ====================================================== | |
| # Tab 2 - Video processing | |
| # ====================================================== | |
# ======================================================
# Tab 2 - Video processing
# ======================================================
def detect_video(video_path):
    """Run detection on the first frame of a video (demo only).

    Args:
        video_path: filesystem path to the uploaded video.

    Returns:
        ``(annotated_frame, counts_dict)`` on success, or
        ``(None, {"Error": ...})`` if the video cannot be read.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        ret, frame = cap.read()
        if not ret:
            # Original code leaked the capture here by returning before
            # release(); the finally block now guarantees cleanup.
            return None, {"Error": "Cannot read video"}
        # demo first frame
        results = model.predict(frame, imgsz=640)
        annotated = results[0].plot()
        dashboard = count_objects(results)
        return annotated, dashboard
    finally:
        cap.release()
| # ====================================================== | |
| # Tab 3 - Live camera | |
| # ====================================================== | |
# ======================================================
# Tab 3 - Live camera
# ======================================================
def detect_camera(frame):
    """Run detection on one webcam frame.

    Returns a tuple ``(annotated_frame, counts_dict)`` consumed by the
    live-camera tab's image output and JSON dashboard.
    """
    preds = model.predict(frame, imgsz=640)
    return preds[0].plot(), count_objects(preds)
| # ====================================================== | |
| # GRADIO interface | |
| # ====================================================== | |
# ======================================================
# GRADIO interface
# ======================================================
with gr.Blocks(title="Rix Detection") as demo:
    # Header string was mojibake ("๐ ๏ธ") from a bad encoding round-trip;
    # restored to the intended wrench emoji.
    gr.Markdown("## 🛠️ Object Counting Dashboard")
    with gr.Tabs():
        # ==================== TAB 1 ====================
        # Single-image upload -> annotated image + per-class counts.
        with gr.Tab("Image Detection"):
            img_input = gr.Image(type="numpy", label="Upload Image")
            img_out = gr.Image(label="Result Image")
            dashboard1 = gr.JSON(label="Counts")
            btn1 = gr.Button("Detect")
            btn1.click(
                fn=detect_image,
                inputs=img_input,
                outputs=[img_out, dashboard1]
            )
        # ==================== TAB 2 ====================
        # Video upload -> detection on the first frame only (demo).
        with gr.Tab("Video Detection"):
            video_input = gr.Video(label="Upload Video")
            video_out = gr.Image(label="Demo Frame Result")
            dashboard2 = gr.JSON(label="Counts")
            btn2 = gr.Button("Detect Video")
            btn2.click(
                fn=detect_video,
                inputs=video_input,
                outputs=[video_out, dashboard2]
            )
        # ==================== TAB 3 ====================
        # Webcam stream -> per-frame detection pushed to the outputs.
        with gr.Tab("Live Camera"):
            cam_input = gr.Image(sources=["webcam"], type="numpy", label="Camera")
            cam_out = gr.Image(label="Real-time Result")
            dashboard3 = gr.JSON(label="Counts")
            cam_input.stream(
                fn=detect_camera,
                inputs=cam_input,
                outputs=[cam_out, dashboard3]
            )
demo.launch()