# 243263STest / app.py
# JingsAPI's picture
# Reduce frame to run easier for video detection
# 89212d3
# Raw
# History Blame Contribute Delete
# 4.02 kB
import gradio as gr
from ultralytics import YOLO
from PIL import Image
import numpy as np
import cv2
import tempfile
import os
# Load the trained YOLO weights once, at import time; the detection
# handlers below all share this single module-level instance.
model = YOLO(model="best.pt")
def detect_image(image):
    """Run the detector on a single uploaded image.

    Args:
        image: The value from the ``gr.Image(type="numpy")`` input, i.e. an
            RGB numpy array (Gradio's default image_mode), or None when
            nothing was uploaded.

    Returns:
        A ``(annotated_image, detection_text)`` tuple: a PIL image with the
        detected boxes drawn on it (None if no image was given) and one
        ``"class: confidence"`` line per detection.
    """
    if image is None:
        return None, "No image uploaded"

    # Gradio hands us RGB, but Ultralytics treats numpy input as BGR (OpenCV
    # channel order). Convert so the model receives the order it expects.
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    result = model(image_bgr, conf=0.5, iou=0.6)[0]

    # plot() draws on the (BGR) input and returns BGR; convert for display.
    annotated_rgb = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)

    detections = [
        f"{model.names[int(box.cls)]}: {float(box.conf):.2f}"
        for box in result.boxes
    ]
    detection_text = "\n".join(detections) if detections else "No objects detected"
    return Image.fromarray(annotated_rgb), detection_text
def detect_video(video_path):
    """Run the detector on a sampled, downscaled copy of an uploaded video.

    Every ``frame_skip``-th frame is resized to 640 px wide, annotated and
    written to a temporary ``.mp4``. The output FPS is divided by the same
    factor, so the annotated clip keeps roughly the original duration.

    Args:
        video_path: Filesystem path of the uploaded video (as passed by
            ``gr.Video``), or None when nothing was uploaded.

    Returns:
        An ``(output_path, status)`` tuple. ``output_path`` is the annotated
        temporary video, or None when the input could not be read or the
        output could not be written; ``status`` describes the outcome.
    """
    if video_path is None:
        return None, "No video uploaded"

    cap = cv2.VideoCapture(video_path)
    out = None
    output_path = None
    keep_output = False  # set only once a valid output video has been written
    try:
        if not cap.isOpened():
            return None, "Could not open the uploaded video"

        # Get video properties
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if width <= 0 or height <= 0:
            # Without this, the scale computation below divides by zero.
            return None, "Could not read the video dimensions"

        # Reduce resolution for CPU (resize to 640 width)
        scale = 640 / width
        new_width = 640
        new_height = max(1, int(height * scale))

        # Process every 3rd frame only (e.g. 30 fps -> 10 fps) to keep CPU
        # inference time manageable.
        frame_skip = 3
        if not original_fps > 0:
            # Some containers report 0 (or NaN) FPS, which VideoWriter cannot
            # use; fall back to a common default rate.
            original_fps = 30.0
        output_fps = original_fps / frame_skip

        # Reserve a temp file name; VideoWriter reopens it by path.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_output:
            output_path = temp_output.name

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, output_fps, (new_width, new_height))
        if not out.isOpened():
            return None, "Could not create the output video"

        frame_count = 0
        processed_count = 0
        print(
            f"Processing video: {total_frames} total frames, "
            f"processing 1 of every {frame_skip} frames..."
        )
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Only process every nth frame
            if frame_count % frame_skip == 0:
                frame_resized = cv2.resize(frame, (new_width, new_height))
                results = model(frame_resized, conf=0.5, iou=0.6, verbose=False)
                # plot() returns the annotated frame in the BGR order that
                # cv2.read() produced and VideoWriter expects.
                out.write(results[0].plot())
                processed_count += 1
            frame_count += 1

        if processed_count == 0:
            return None, "No frames could be read from the video"

        keep_output = True
        return output_path, f"Done! Processed {processed_count} frames from {total_frames} total frames"
    finally:
        # Release both handles even if inference raises, so the writer
        # finalises the file and neither handle leaks.
        cap.release()
        if out is not None:
            out.release()
        if output_path is not None and not keep_output and os.path.exists(output_path):
            # Don't leave empty or half-written temp videos behind on failure.
            os.remove(output_path)
# Gradio UI: one tab for single-image detection and one for video detection,
# each wired to its handler above.
with gr.Blocks(title="243263S - Traffic Cone & Cardboard Box Detector") as demo:
    gr.Markdown("# 243263S - Traffic Cone & Cardboard Box Detector")
    gr.Markdown("Upload an image or video to detect **traffic cones** and **cardboard boxes**!")

    with gr.Tab("Image Detection"):
        with gr.Row():
            image_input = gr.Image(type="numpy", label="Upload Image")
            image_output = gr.Image(label="Detection Result")
        detection_text = gr.Textbox(label="Detections")
        image_btn = gr.Button("Detect!", variant="primary")
        image_btn.click(
            fn=detect_image,
            inputs=image_input,
            outputs=[image_output, detection_text],
        )

    with gr.Tab("Video Detection"):
        gr.Markdown("⚠️ Video processing may take a few minutes on CPU. Please be patient!")
        with gr.Row():
            video_input = gr.Video(label="Upload Video")
            video_output = gr.Video(label="Detection Result")
        video_status = gr.Textbox(label="Status")
        video_btn = gr.Button("Detect!", variant="primary")
        video_btn.click(
            fn=detect_video,
            inputs=video_input,
            outputs=[video_output, video_status],
        )

# Launch only when run as a script (`python app.py`, as Spaces does), so that
# importing this module (tests, `gradio app.py` reload mode) doesn't start a server.
if __name__ == "__main__":
    demo.launch()