from ultralytics import YOLO
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
import tempfile
import os
import cv2
# Load the YOLO model from Hugging Face
def load_model(repo_id):
    """Fetch a model snapshot from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository identifier (e.g. ``"user/repo"``).

    Returns:
        A ``YOLO`` model built from the snapshot's ``best.pt`` weights.
    """
    snapshot_dir = snapshot_download(repo_id)
    weights_file = os.path.join(snapshot_dir, "best.pt")
    return YOLO(weights_file)
# Process image input
def predict_image(image, conf_threshold, iou_threshold):
    """Run detection on a single image and return the annotated copy.

    Args:
        image: Input image (PIL image, as supplied by the Gradio widget).
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold used during non-max suppression.

    Returns:
        A PIL image with the detections drawn on it.
    """
    predictions = detection_model.predict(
        image, conf=conf_threshold, iou=iou_threshold
    )
    annotated_bgr = predictions[0].plot()
    # plot() yields a BGR array; reverse the channel axis to get RGB for PIL.
    annotated_rgb = annotated_bgr[..., ::-1]
    return Image.fromarray(annotated_rgb)
# Process video input
def predict_video(video_path, conf_threshold, iou_threshold):
    """Run detection on every frame of a video and write an annotated copy.

    Args:
        video_path: Path to the input video file.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold used during non-max suppression.

    Returns:
        Path to a temporary ``.mp4`` file containing the annotated frames.

    Raises:
        ValueError: If the input video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    # Fail loudly instead of silently producing an empty/broken output file.
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Some containers report 0/NaN FPS; fall back to a sane default so the
    # VideoWriter is still valid.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

    # Close the temp-file handle immediately: cv2 writes by path, and an open
    # handle prevents reopening the file on Windows.
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    out_path = tmp.name
    tmp.close()

    out_writer = cv2.VideoWriter(
        out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)
    )
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            result = detection_model.predict(
                frame, conf=conf_threshold, iou=iou_threshold
            )
            out_writer.write(result[0].plot())
    finally:
        # Release capture and writer even if inference raises mid-loop,
        # so the partially written file is at least finalized.
        cap.release()
        out_writer.release()
    return out_path
# Load model
# Hugging Face repository hosting the trained YOLOv8 weights ("best.pt").
REPO_ID = "Cedri/battery_key_yolov8"
# Module-level model instance shared by predict_image and predict_video.
# NOTE: this downloads the snapshot at import time (network required).
detection_model = load_model(REPO_ID)
# Gradio UI: two tabs (image / video), each with its own threshold sliders
# and a button wired to the matching predict_* function.
with gr.Blocks() as demo:
    gr.Markdown("## Battery Key Detection - Image & Video")
    with gr.Tabs():
        # --- Image tab ---
        with gr.TabItem("Image"):
            with gr.Row():
                image_in = gr.Image(type="pil", label="Upload Image")
                image_out = gr.Image(type="pil", label="Predicted Image")
            image_conf = gr.Slider(0.1, 1.0, 0.5, step=0.05, label="Confidence Threshold")
            image_iou = gr.Slider(0.1, 1.0, 0.6, step=0.05, label="IoU Threshold")
            image_btn = gr.Button("Run Detection on Image")
            image_btn.click(
                fn=predict_image,
                inputs=[image_in, image_conf, image_iou],
                outputs=image_out,
            )
        # --- Video tab ---
        with gr.TabItem("Video"):
            with gr.Row():
                video_in = gr.Video(label="Upload Video")
                video_out = gr.Video(label="Predicted Video")
            video_conf = gr.Slider(0.1, 1.0, 0.5, step=0.05, label="Confidence Threshold")
            video_iou = gr.Slider(0.1, 1.0, 0.6, step=0.05, label="IoU Threshold")
            video_btn = gr.Button("Run Detection on Video")
            video_btn.click(
                fn=predict_video,
                inputs=[video_in, video_conf, video_iou],
                outputs=video_out,
            )

demo.launch()