|
|
import gradio as gr |
|
|
import cv2 |
|
|
import numpy as np |
|
|
from ultralytics import YOLO |
|
|
import threading |
|
|
import time |
|
|
import os |
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Label list paired with the stock "yolov8n.pt" checkpoint. Kept in one place
# so the "best.pt missing" path and the error-fallback path cannot drift apart.
# NOTE(review): presumably the COCO class-id order — verify against model.names.
_COCO_CLASS_NAMES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
    "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
    "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball",
    "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
    "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
    "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
    "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]

# Load the custom checkpoint when it is present; otherwise (or when loading
# raises) fall back to the stock YOLOv8n weights and their label list.
try:
    if not os.path.exists("best.pt"):
        print("Warning: best.pt not found. Using default YOLOv8n model.")
        model = YOLO("yolov8n.pt")
        CLASS_NAMES = list(_COCO_CLASS_NAMES)
    else:
        model = YOLO("best.pt")
        CLASS_NAMES = ["hard hat", "mask"]
except Exception as e:
    print(f"Error loading model: {e}")
    model = YOLO("yolov8n.pt")
    CLASS_NAMES = list(_COCO_CLASS_NAMES)
|
|
|
|
|
|
|
|
# State shared between the camera capture thread and the Gradio handlers.
camera_active = False  # toggled by start_camera / stop_camera, polled by camera_thread
current_frame = None  # most recent frame stored by camera_thread (None until one arrives)
frame_lock = threading.Lock()  # guards reads and writes of current_frame

# Scratch directory for generated videos. The handle must stay referenced at
# module level: TemporaryDirectory removes the directory when the handle is
# finalized (at the latest on interpreter exit), so the directory is no longer
# leaked the way a bare mkdtemp() directory was. temp_dir remains a plain str.
_temp_dir_handle = tempfile.TemporaryDirectory()
temp_dir = _temp_dir_handle.name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_image(input_image, selected_classes):
    """Run detection on a single image and return it annotated, plus a tally.

    Args:
        input_image: A file path, a NumPy array, or any object that
            ``np.array`` can convert. Non-path inputs are assumed to be in
            RGB / RGBA / grayscale channel order — TODO confirm with callers.
        selected_classes: Labels to draw and count; other detections are
            ignored.

    Returns:
        ``(annotated_rgb_image, tally_text)`` on success, or
        ``(None, message)`` when no image is given, the file cannot be read,
        or processing raises.
    """
    if input_image is None:
        return None, "No image uploaded"

    try:
        if isinstance(input_image, str):
            # cv2.imread already returns BGR data, so it must NOT go through
            # the RGB->BGR swap below (doing so inverted the channels for both
            # inference and the displayed result).
            frame = cv2.imread(input_image)
            if frame is None:
                return None, "Could not load image file"
        else:
            if isinstance(input_image, np.ndarray):
                frame = input_image
            else:
                frame = np.array(input_image)

            # Normalise in-memory inputs to 3-channel BGR for OpenCV drawing
            # and for the BGR->RGB conversion applied to the output.
            if frame.ndim == 2:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
            elif frame.ndim == 3 and frame.shape[2] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)
            elif frame.ndim == 3 and frame.shape[2] == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        results = model.predict(frame, conf=0.25, verbose=False)
        frame_out = frame.copy()

        detection_count = {cls: 0 for cls in selected_classes}

        for r in results:
            if r.boxes is None:
                continue
            for box in r.boxes:
                cls_id = int(box.cls[0])
                conf = float(box.conf[0])
                label = CLASS_NAMES[cls_id] if cls_id < len(CLASS_NAMES) else f"cls{cls_id}"

                # detection_count has exactly the selected labels as keys.
                if label not in detection_count:
                    continue

                x1, y1, x2, y2 = map(int, box.xyxy[0])
                cv2.rectangle(frame_out, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(frame_out, f"{label} {conf:.2f}",
                            (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                            (0, 255, 0), 2)
                detection_count[label] += 1

        # Convert back to RGB for the caller.
        frame_rgb = cv2.cvtColor(frame_out, cv2.COLOR_BGR2RGB)

        tally_text = "\n".join(
            f"{cls}: {count} detections" for cls, count in detection_count.items()
        )

        return frame_rgb, tally_text

    except Exception as e:
        return None, f"Error processing image: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_video(input_file, selected_classes):
    """Annotate up to the first 1000 frames of a video and tally detections.

    Args:
        input_file: Path of the video to read (passed to ``cv2.VideoCapture``).
        selected_classes: Labels to draw and count; other detections are
            ignored.

    Returns:
        ``(output_path, tally_text)`` on success, where ``output_path`` is
        ``output.mp4`` inside the module's ``temp_dir``; ``(None, message)``
        when no file is given, the input cannot be opened, the output cannot
        be created, or processing raises.
    """
    if input_file is None:
        return None, "No file uploaded"

    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(input_file)
        if not cap.isOpened():
            return None, "Could not read input file"

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        # Keep the source frame rate so playback speed is preserved; 20 is
        # only a fallback when no usable rate is reported (0, negative, NaN).
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 20
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        out_path = os.path.join(temp_dir, "output.mp4")
        out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        if not out.isOpened():
            # Without this check every write() would silently do nothing and
            # an unusable file would be returned.
            return None, "Could not create output video"

        tally_counts = {cls: 0 for cls in selected_classes}
        frame_count = 0
        max_frames = 1000  # cap on the number of frames processed per call

        while frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break

            results = model.predict(frame, conf=0.25, verbose=False)
            frame_out = frame.copy()

            for r in results:
                if r.boxes is None:
                    continue
                for box in r.boxes:
                    cls_id = int(box.cls[0])
                    conf = float(box.conf[0])
                    label = CLASS_NAMES[cls_id] if cls_id < len(CLASS_NAMES) else f"cls{cls_id}"

                    # tally_counts has exactly the selected labels as keys.
                    if label not in tally_counts:
                        continue

                    x1, y1, x2, y2 = map(int, box.xyxy[0])
                    cv2.rectangle(frame_out, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    cv2.putText(frame_out, f"{label} {conf:.2f}",
                                (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                                (0, 255, 0), 2)
                    tally_counts[label] += 1

            out.write(frame_out)
            frame_count += 1

        if frame_count >= max_frames:
            tally_text = f"Processed first {max_frames} frames.\n"
        else:
            tally_text = ""

        tally_text += "\n".join(
            f"{cls}: {count} detections" for cls, count in tally_counts.items()
        )

        # The finally clause below releases the writer (finalising the file)
        # before control actually returns to the caller.
        return out_path, tally_text

    except Exception as e:
        return None, f"Error processing video: {str(e)}"

    finally:
        # Release on every exit path, including errors raised mid-loop.
        if out is not None:
            out.release()
        if cap is not None:
            cap.release()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def camera_thread():
    """Background thread to capture camera frames.

    Opens capture device 0 and, while ``camera_active`` is true, stores a copy
    of each successfully read frame in ``current_frame`` under ``frame_lock``.
    Errors are printed rather than raised. The device is released on every
    exit path, including when the loop raises.
    """
    global camera_active, current_frame

    cap = None
    try:
        cap = cv2.VideoCapture(0)
        if not cap.isOpened():
            print("Warning: Could not open camera - this is expected on Hugging Face Spaces")
            return

        # Requested capture settings.
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        cap.set(cv2.CAP_PROP_FPS, 30)

        while camera_active:
            ret, frame = cap.read()
            if ret:
                with frame_lock:
                    current_frame = frame.copy()
            time.sleep(0.033)  # pause between reads

    except Exception as e:
        print(f"Camera error: {e}")

    finally:
        # Previously the release was skipped when the loop raised, leaving the
        # device held by a dead thread.
        if cap is not None:
            cap.release()
|
|
|
|
|
# Handle of the most recently started capture thread, so start_camera can tell
# whether a capture loop is still running.
_camera_thread_obj = None


def start_camera():
    """Start the camera streaming.

    Starts one daemon thread running ``camera_thread``. If a capture loop is
    already running, no second thread is started (repeated clicks used to
    spawn several loops competing for the same device).

    Returns:
        A status message for display.
    """
    global camera_active, _camera_thread_obj

    previous = _camera_thread_obj
    if previous is not None and previous.is_alive():
        if camera_active:
            return "Camera already running"
        # A stop was requested but the old loop has not exited yet; wait
        # briefly so that re-enabling the flag does not revive it alongside
        # the new thread.
        previous.join(timeout=1.0)

    camera_active = True
    _camera_thread_obj = threading.Thread(target=camera_thread, daemon=True)
    _camera_thread_obj.start()
    return "Camera started (Note: Camera access may be limited on Hugging Face Spaces)"
|
|
|
|
|
def stop_camera():
    """Clear the ``camera_active`` flag and return a status message."""
    global camera_active

    status_message = "Camera stopped"
    camera_active = False
    return status_message
|
|
|
|
|
def get_camera_frame(selected_classes):
    """Annotate the most recent camera frame with the selected detections.

    Args:
        selected_classes: Labels to draw and count; other detections are
            ignored.

    Returns:
        ``(annotated_rgb_frame, tally_text)`` on success, or
        ``(None, message)`` when the camera is inactive, no frame has been
        captured yet, or processing raises.
    """
    global current_frame

    camera_ready = camera_active and current_frame is not None
    if not camera_ready:
        return None, "Camera not available or no frame captured"

    try:
        # Copy under the lock so the capture thread can keep writing.
        with frame_lock:
            snapshot = current_frame.copy()

        predictions = model.predict(snapshot, conf=0.25, verbose=False)
        annotated = snapshot.copy()

        counts = dict.fromkeys(selected_classes, 0)

        for result in predictions:
            boxes = result.boxes
            if boxes is None:
                continue

            for detection in boxes:
                class_index = int(detection.cls[0])
                score = float(detection.conf[0])
                if class_index < len(CLASS_NAMES):
                    label = CLASS_NAMES[class_index]
                else:
                    label = f"cls{class_index}"

                if label not in selected_classes:
                    continue

                left, top, right, bottom = (int(value) for value in detection.xyxy[0])
                green = (0, 255, 0)
                cv2.rectangle(annotated, (left, top), (right, bottom), green, 2)
                caption = f"{label} {score:.2f}"
                cv2.putText(annotated, caption, (left, top - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, green, 2)
                counts[label] += 1

        # Convert from OpenCV's BGR order to RGB for display.
        annotated_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

        summary_lines = [f"{name}: {total} detections" for name, total in counts.items()]
        return annotated_rgb, "\n".join(summary_lines)

    except Exception as e:
        return None, f"Error processing camera frame: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_ui(mode):
    """Return nine visibility updates for the selected detection mode.

    The flags are positional and line up with the ``outputs`` list of the
    ``mode.change`` wiring: image upload, video upload, video output, image
    output, run button, start button, stop button, refresh button, camera
    warning. Any mode other than the two upload modes gets the camera layout.
    """
    if mode == "Upload Image":
        visibility = (True, False, False, True, True, False, False, False, False)
    elif mode == "Upload Video":
        visibility = (False, True, True, False, True, False, False, False, False)
    else:
        visibility = (False, False, False, True, False, True, True, True, True)

    return tuple(gr.update(visible=shown) for shown in visibility)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: controls in the left column, outputs in the right column,
# followed by the event wiring. Component creation order defines the layout.
with gr.Blocks(title="YOLO Detector") as demo:
    gr.Markdown("## 🦺 YOLO Object Detector")
    gr.Markdown("Upload images or videos for object detection. Note: Live camera may not work on Hugging Face Spaces due to browser security restrictions.")

    with gr.Row():
        with gr.Column():
            mode = gr.Radio(
                ["Upload Image", "Upload Video", "Live Camera"],
                value="Upload Image",
                label="Detection Mode",
            )

            input_file = gr.File(
                label="Upload Image",
                type="filepath",
                file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"],
                visible=True,
            )

            input_video = gr.File(
                label="Upload Video",
                type="filepath",
                file_types=[".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm"],
                visible=False,
            )

            # Offer at most the first ten labels; the first two start checked.
            available_classes = CLASS_NAMES if len(CLASS_NAMES) <= 10 else CLASS_NAMES[:10]
            class_toggle = gr.CheckboxGroup(
                available_classes,
                value=available_classes[:2],
                label="Select classes to detect",
            )

            # Only the run button is visible in the initial "Upload Image" mode.
            run_btn = gr.Button("Run Detection", variant="primary", visible=True)
            start_btn = gr.Button("Start Camera", visible=False)
            stop_btn = gr.Button("Stop Camera", visible=False)
            refresh_btn = gr.Button("Refresh Live Feed", visible=False)

        with gr.Column():
            camera_warning = gr.Markdown(
                "⚠️ **Note:** Live camera access is typically not available on Hugging Face Spaces due to security restrictions. Use image/video upload instead.",
                visible=False,
            )

            output_video = gr.Video(label="Detection Output", visible=False)
            output_img = gr.Image(type="numpy", label="Detection Output", visible=True)
            tally_box = gr.Textbox(label="Detection Count", interactive=False, lines=5)

    # Show or hide components whenever the detection mode changes.
    mode.change(
        update_ui,
        inputs=mode,
        outputs=[
            input_file,
            input_video,
            output_video,
            output_img,
            run_btn,
            start_btn,
            stop_btn,
            refresh_btn,
            camera_warning,
        ],
    )

    def run_detection(input_file, input_video, selected_classes, mode):
        """Dispatch to image or video detection.

        Returns ``(image, video, tally_text)``; the slot that does not apply
        to the current mode is ``None``.
        """
        if not selected_classes:
            return None, None, "Please select at least one class to detect"

        try:
            if mode == "Upload Image":
                if input_file is not None:
                    annotated, summary = predict_image(input_file, selected_classes)
                    return annotated, None, summary
                return None, None, "No image uploaded"

            if mode == "Upload Video":
                if input_video is not None:
                    rendered, summary = predict_video(input_video, selected_classes)
                    return None, rendered, summary
                return None, None, "No video uploaded"

            return None, None, "Invalid mode selected"
        except Exception as e:
            return None, None, f"Error: {str(e)}"

    run_btn.click(
        run_detection,
        inputs=[input_file, input_video, class_toggle, mode],
        outputs=[output_img, output_video, tally_box],
    )

    # Camera start/stop report their status text in the tally box.
    start_btn.click(start_camera, outputs=tally_box)
    stop_btn.click(stop_camera, outputs=tally_box)

    def update_live_feed(selected_classes):
        """Return the annotated live frame, or ``None`` plus a reason."""
        if not selected_classes:
            return None, "Please select at least one class to detect"
        if not camera_active:
            return None, "Camera not active"
        return get_camera_frame(selected_classes)

    refresh_btn.click(
        update_live_feed,
        inputs=[class_toggle],
        outputs=[output_img, tally_box],
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Launch settings collected in one place: host, port, and share flag.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)