# Hugging Face Space by Nomi78600 — "final" (commit 298b1a0, verified)
import gradio as gr
from ultralytics import YOLO
from ultralytics.yolo.utils.plotting import Annotator # Import Annotator
import numpy as np
import cv2
import tempfile
import os
import torch
from huggingface_hub import hf_hub_download
# --- Model Loading ---
MODEL_REPO = "Nomi78600/keyboard-mouse-detection-yolov8"
MODEL_FILENAME = "best.pt"

# PyTorch 2.6+ defaults to `weights_only=True` in `torch.load` for security.
# This checkpoint stores the full model object and fails that check, so we
# temporarily monkey-patch `torch.load` to force `weights_only=False`.
# This is acceptable only because we load our own trusted model file.

# Keep a reference to the original so we can restore it afterwards.
original_torch_load = torch.load


def new_torch_load(*args, **kwargs):
    """Wrapper around torch.load that forces weights_only=False."""
    kwargs['weights_only'] = False
    return original_torch_load(*args, **kwargs)


torch.load = new_torch_load

try:
    # Explicitly download the checkpoint from the Hugging Face Hub, then load
    # it; YOLO() calls our patched torch.load internally.
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    model = YOLO(model_path)
except Exception as e:
    # Surface the failure in a minimal Gradio UI instead of a bare traceback.
    with gr.Blocks() as app:
        gr.Markdown("# ❌ Model Loading Error")
        gr.Markdown(f"Could not load model (`{MODEL_FILENAME}`) from Hugging Face Hub repo: `{MODEL_REPO}`.")
        gr.Markdown("Please ensure the repository is public, the file exists, and you have disabled gated access.")
        gr.Textbox(str(e), label="Error Details")
    app.launch()
    # Stop the script if the model fails to load
    exit()
finally:
    # BUGFIX: restore the original torch.load unconditionally. The original
    # code only restored it on success, leaving the unsafe patch in place
    # whenever loading failed.
    torch.load = original_torch_load
# --- Detection Logic ---
def process_video(video_path, confidence, progress=gr.Progress()):
    """
    Run detection on every frame of a video and write an annotated copy.

    Args:
        video_path (str): Path to the input video file.
        confidence (float): Confidence threshold passed to the model.
        progress: Gradio progress tracker (injected automatically).

    Returns:
        str: Path to the annotated output video (mp4).

    Raises:
        gr.Error: If the input video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise gr.Error("Could not open video file.")

    # Get video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    if fps <= 0:
        # ROBUSTNESS: some containers report 0 FPS; fall back to a sane
        # default so the VideoWriter still produces a playable file.
        fps = 30
    # Guard against 0 so the progress division below cannot raise.
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 1

    # Create a temporary file to save the output video
    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "output.mp4")

    # Define the codec and create VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    progress(0, desc="Starting video processing...")
    frame_count = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Perform detection on the raw frame
            results = model(frame, conf=confidence)
            # Use Annotator to draw boxes on a copy (keep the original intact)
            annotator = Annotator(frame.copy())
            if results[0].boxes:
                for box in results[0].boxes:
                    b = box.xyxy[0]
                    c = box.cls
                    label = f"{model.names[int(c)]} {float(box.conf):.2f}"
                    annotator.box_label(b, label)
            # Write the annotated frame to the output video
            out.write(annotator.result())
            frame_count += 1
            progress(frame_count / total_frames, desc=f"Processing frame {frame_count}/{total_frames}")
    finally:
        # BUGFIX: release the capture and writer even if detection raises
        # mid-stream; the original leaked both handles on any exception.
        cap.release()
        out.release()
    return output_path
def detect(image, confidence):
    """
    Annotate model detections on a single image.

    Args:
        image (np.ndarray): Input image, or None when nothing was supplied.
        confidence (float): Confidence threshold for the detector.

    Returns:
        np.ndarray | None: Annotated copy of the image, or None for no input.
    """
    # Nothing to do without an input image (e.g. webcam feed not started yet).
    if image is None:
        return None

    # Run inference, then draw every predicted box on a copy of the input.
    results = model(image, conf=confidence)
    annotator = Annotator(image.copy())

    detections = results[0].boxes
    if detections:
        for det in detections:
            xyxy = det.xyxy[0]
            cls_id = int(det.cls)
            caption = f"{model.names[cls_id]} {float(det.conf):.2f}"
            annotator.box_label(xyxy, caption)

    return annotator.result()
# --- Gradio Interface ---
# Build the Gradio UI: a shared confidence slider plus three tabs (image,
# video, webcam) that all route into the detection functions above.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
        # 🖱️ Keyboard & Mouse Detection ⌨️
        Powered by **YOLOv8** and deployed on **Hugging Face Spaces**.
        This application detects keyboards and mice from three sources:
        1. **Image:** Upload a static image.
        2. **Video:** Upload a video file for frame-by-frame analysis.
        3. **Webcam:** Use your webcam for live, real-time detection.
        Adjust the **Confidence Threshold** to filter detections.
        """
    )
    with gr.Row():
        # Single slider shared by all three tabs; passed as the second input
        # to both `detect` and `process_video`.
        confidence_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.05,
            label="Confidence Threshold",
            info="Lower values detect more objects; higher values are more selective."
        )
    with gr.Tabs():
        # Image Detection Tab
        with gr.TabItem("🖼️ Image Detection"):
            with gr.Row():
                # `type="numpy"` so `detect` receives/returns np.ndarray frames.
                image_input = gr.Image(type="numpy", label="Upload Image")
                image_output = gr.Image(type="numpy", label="Detected Image")
            image_button = gr.Button("Detect Objects", variant="primary")
            image_button.click(
                fn=detect,
                inputs=[image_input, confidence_slider],
                outputs=image_output
            )
        # Video Detection Tab
        with gr.TabItem("▶️ Video Detection"):
            with gr.Row():
                video_input = gr.Video(label="Upload Video")
                video_output = gr.Video(label="Detected Video")
            video_button = gr.Button("Process Video", variant="primary")
            video_button.click(
                fn=process_video,
                inputs=[video_input, confidence_slider],
                outputs=video_output
            )
        # Webcam Detection Tab
        with gr.TabItem("📷 Live Webcam"):
            with gr.Row():
                webcam_input = gr.Image(sources=["webcam"], type="numpy", label="Webcam Feed")
                webcam_output = gr.Image(label="Webcam Detection")
            # Use the 'change' event to process frames as they come in from the webcam
            webcam_input.change(
                fn=detect,
                inputs=[webcam_input, confidence_slider],
                outputs=webcam_output
            )
    gr.Markdown(
        """
        ---
        *Model hosted on [Hugging Face Hub](https://huggingface.co/Nomi78600/keyboard-mouse-detection-yolov8)*
        """
    )

# Script entry point: launch the Gradio server.
if __name__ == "__main__":
    app.launch()