# 230838D / app.py
# Cedri's picture
# Update app.py
# a2dd7f7 verified
from ultralytics import YOLO
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
import tempfile
import os
import cv2
# Load the YOLO model from Hugging Face
def load_model(repo_id, filename="best.pt"):
    """Download a YOLO checkpoint from a Hugging Face Hub repo and load it.

    Only the weights file is fetched (via ``allow_patterns``) rather than the
    entire repository snapshot.

    Args:
        repo_id: Hub repository id, e.g. ``"user/model"``.
        filename: Path of the weights file inside the repo. Defaults to
            ``"best.pt"``, the name this app has always used.

    Returns:
        An ``ultralytics.YOLO`` model loaded from the downloaded weights.

    Raises:
        FileNotFoundError: If the repo does not contain ``filename``.
    """
    download_dir = snapshot_download(repo_id, allow_patterns=[filename])
    model_path = os.path.join(download_dir, filename)
    # Fail with an explicit message instead of handing YOLO a path that does
    # not exist locally.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(
            f"{filename!r} not found in Hugging Face repo {repo_id!r}"
        )
    return YOLO(model_path)
# Process image input
def predict_image(image, conf_threshold, iou_threshold):
    """Run detection on a single image and return it annotated.

    Args:
        image: Image from the Gradio ``Image(type="pil")`` input, or ``None``
            when the button is pressed before anything is uploaded.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold passed to the model's NMS.

    Returns:
        A ``PIL.Image`` with detections drawn on it, or ``None`` (which clears
        the output component) if no image was provided.
    """
    # Without this guard, ultralytics treats ``source=None`` as "use the
    # bundled sample assets" and would annotate those instead.
    if image is None:
        return None
    results = detection_model.predict(
        image, conf=conf_threshold, iou=iou_threshold, verbose=False
    )
    # ``Results.plot()`` returns a BGR array; reverse the channel axis so PIL
    # receives RGB.
    annotated_bgr = results[0].plot()
    return Image.fromarray(annotated_bgr[..., ::-1])
# Process video input
def predict_video(video_path, conf_threshold, iou_threshold, fallback_fps=30.0):
    """Run detection on every frame of a video and write an annotated copy.

    Args:
        video_path: Path of the uploaded video (from ``gr.Video``), or
            ``None``/empty when nothing has been uploaded.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold passed to the model's NMS.
        fallback_fps: Frame rate used when the container reports no usable
            FPS (0, negative, NaN or infinite).

    Returns:
        Path to a temporary ``.mp4`` file containing the annotated frames, or
        ``None`` if no video was provided.

    Raises:
        gr.Error: If the video cannot be opened or yields no frames.
    """
    if not video_path:
        return None

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise gr.Error(f"Could not open video: {video_path}")

    fps = cap.get(cv2.CAP_PROP_FPS)
    # Chained comparison is False for NaN, 0, negatives and inf alike.
    if not (0 < fps < float("inf")):
        fps = fallback_fps

    # mkstemp hands back an open descriptor; close it so only the path is
    # kept and VideoWriter can open the file itself.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out_writer = None
    frame_count = 0
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            results = detection_model.predict(
                frame, conf=conf_threshold, iou=iou_threshold, verbose=False
            )
            annotated = results[0].plot()
            if out_writer is None:
                # Size the writer from the first real frame: VideoWriter
                # drops frames whose size differs from the one it was opened
                # with, and container metadata can disagree with the frames.
                height, width = annotated.shape[:2]
                out_writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
            out_writer.write(annotated)
            frame_count += 1
    finally:
        # Release both handles even if prediction raises mid-video.
        cap.release()
        if out_writer is not None:
            out_writer.release()

    if frame_count == 0:
        os.remove(out_path)
        raise gr.Error(f"No frames could be read from video: {video_path}")
    return out_path
# Model is loaded once at import time; predict_image and predict_video both
# read this module-level instance.
REPO_ID: str = "Cedri/battery_key_yolov8"  # Hugging Face Hub repository id
detection_model = load_model(repo_id=REPO_ID)
# Gradio UI
def _threshold_sliders():
    """Create the confidence/IoU slider pair used by both tabs."""
    conf_slider = gr.Slider(0.1, 1.0, 0.5, step=0.05, label="Confidence Threshold")
    iou_slider = gr.Slider(0.1, 1.0, 0.6, step=0.05, label="IoU Threshold")
    return conf_slider, iou_slider


with gr.Blocks() as demo:
    gr.Markdown("## Battery Key Detection - Image & Video")
    with gr.Tabs():
        with gr.TabItem("Image"):
            with gr.Row():
                img_input = gr.Image(type="pil", label="Upload Image")
                img_output = gr.Image(type="pil", label="Predicted Image")
            conf_slider_img, iou_slider_img = _threshold_sliders()
            run_btn_img = gr.Button("Run Detection on Image")
            run_btn_img.click(
                fn=predict_image,
                inputs=[img_input, conf_slider_img, iou_slider_img],
                outputs=img_output,
            )
        with gr.TabItem("Video"):
            with gr.Row():
                vid_input = gr.Video(label="Upload Video")
                vid_output = gr.Video(label="Predicted Video")
            conf_slider_vid, iou_slider_vid = _threshold_sliders()
            run_btn_vid = gr.Button("Run Detection on Video")
            run_btn_vid.click(
                fn=predict_video,
                inputs=[vid_input, conf_slider_vid, iou_slider_vid],
                outputs=vid_output,
            )

# Start the server only when run as a script; importing the module still
# builds `demo` but does not block on launch().
if __name__ == "__main__":
    demo.launch()