# Zynaly's picture
# Update app.py
# 4380f79 verified
import cv2
import numpy as np
from ultralytics import YOLO
import gradio as gr
import tempfile
import os
import cv2
import tempfile
# Load the trained paw-detection model (an ONNX export, per the file name)
# once at import time; both detection functions below reuse this instance.
model = YOLO("mouse_paw_detection.onnx")
# -------------------- IMAGE DETECTION FUNCTION --------------------
def detect_paws_image(image):
    """Run paw detection on a single image and draw the detections on it.

    Args:
        image: RGB image as a numpy array, as delivered by
            ``gr.Image(type="numpy")``; ``None`` when the user submitted
            without uploading anything.

    Returns:
        A tuple ``(annotated_rgb_image, summary)``. When ``image`` is ``None``
        the tuple is ``(None, "No image provided.")``; otherwise ``summary``
        is ``"Detected Paws: <count>"``.
    """
    if image is None:
        # cv2.cvtColor would raise on None; report it instead of crashing.
        return None, "No image provided."

    # Gradio supplies RGB; Ultralytics treats numpy inputs as BGR (OpenCV
    # convention), and the drawing below is done in BGR as well.
    img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    results = model(img)

    paw_count = 0
    for result in results:
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            conf = float(box.conf[0])
            paw_count += 1
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Keep the label on-screen when the box touches the top edge.
            cv2.putText(img, f'Paw {conf:.2f}', (x1, max(y1 - 10, 15)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), f"Detected Paws: {paw_count}"
# -------------------- VIDEO DETECTION FUNCTION --------------------
def detect_paws_video(video_path):
    """Run paw detection on every frame of a video and save an annotated copy.

    Args:
        video_path: Filesystem path of the uploaded video, as passed by
            ``gr.Video``; ``None`` when the user submitted without a file.

    Returns:
        A tuple ``(output_path, summary)``. ``output_path`` is a temporary
        ``.mp4`` file (created with ``delete=False``; this function does not
        remove it) or ``None`` on failure. ``summary`` is either an error
        message or the mean number of detected paws per frame, formatted
        with two decimals.
    """
    if video_path is None:
        return None, "No video provided."

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, "Error opening video file."

    # Keep FPS as a float: truncating e.g. 29.97 to 29 makes the output drift
    # out of sync. Some containers report 0 (or NaN), which VideoWriter cannot
    # use, so fall back to a fixed rate in that case.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # written this way so NaN also takes the fallback
        fps = 30.0  # NOTE(review): arbitrary fallback rate; confirm it suits the data
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Only a unique path is needed; close the Python handle immediately instead
    # of leaking an open file descriptor on every request.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_output:
        output_path = temp_output.name

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    if not out.isOpened():
        # Without this check every out.write() would silently do nothing.
        cap.release()
        os.remove(output_path)
        return None, "Error creating output video file."

    total_paws = 0
    total_frames = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            results = model(frame)
            for result in results:
                for box in result.boxes:
                    x1, y1, x2, y2 = map(int, box.xyxy[0])
                    conf = float(box.conf[0])
                    total_paws += 1
                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    # Keep the label on-screen when the box touches the top edge.
                    cv2.putText(frame, f'Paw {conf:.2f}', (x1, max(y1 - 10, 15)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
            total_frames += 1
            out.write(frame)
    finally:
        # Release both handles even if inference raises mid-video.
        cap.release()
        out.release()

    # True division: floor division reported e.g. 0.8 paws/frame as 0.
    avg_paws = total_paws / total_frames if total_frames > 0 else 0.0
    return output_path, f"Average Paws Per Frame: {avg_paws:.2f}"
# -------------------- GRADIO INTERFACES --------------------
# Image tab: wraps detect_paws_image. The uploaded image arrives as a numpy
# array, and the function's two return values fill the outputs in order.
image_interface = gr.Interface(
    title="Mouse Paw Detection - Image",
    fn=detect_paws_image,
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=[
        gr.Image(type="numpy", label="Detected Output"),
        gr.Label(label="Paw Count"),
    ],
)

# Video tab: wraps detect_paws_video. The function receives the uploaded
# video's file path and returns the processed video plus a summary string.
video_interface = gr.Interface(
    title="Mouse Paw Detection - Video",
    fn=detect_paws_video,
    inputs=gr.Video(label="Upload Video"),
    outputs=[
        gr.Video(label="Processed Video"),
        gr.Label(label="Average Paw Count"),
    ],
)

# Combine both interfaces into one tabbed app; tab names pair with the
# interfaces by position.
app = gr.TabbedInterface(
    interface_list=[image_interface, video_interface],
    tab_names=["Image Detection", "Video Detection"],
)

# Start the server as soon as the script runs (same as before: unguarded).
app.launch()