"""Gradio demo for battery-key detection with a YOLO model hosted on Hugging Face.

Two tabs are exposed: one runs detection on a single uploaded image, the other
runs detection frame-by-frame on an uploaded video and returns an annotated copy.
"""

import os
import tempfile

import cv2
import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image
from ultralytics import YOLO

REPO_ID = "Cedri/battery_key_yolov8"

# Used when OpenCV cannot read a frame rate from the uploaded container
# (CAP_PROP_FPS may come back as 0); VideoWriter needs a positive fps.
DEFAULT_FPS = 30.0


def load_model(repo_id, filename="best.pt"):
    """Download a Hugging Face Hub repository and load its YOLO weights.

    Args:
        repo_id: Hub repository id, e.g. "user/model".
        filename: Weights file inside the downloaded snapshot (default "best.pt").

    Returns:
        A ``YOLO`` model built from ``<snapshot>/<filename>``.
    """
    download_dir = snapshot_download(repo_id)
    model_path = os.path.join(download_dir, filename)
    return YOLO(model_path)


def predict_image(image, conf_threshold, iou_threshold):
    """Run detection on one image and return the annotated result.

    Args:
        image: PIL image from the upload widget, or None if nothing was uploaded.
        conf_threshold: Minimum detection confidence passed to ``predict``.
        iou_threshold: IoU threshold for non-maximum suppression passed to ``predict``.

    Returns:
        A PIL image with the detections drawn on it.

    Raises:
        gr.Error: If no image was provided (shown to the user in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    results = detection_model.predict(
        image, conf=conf_threshold, iou=iou_threshold, verbose=False
    )
    annotated_bgr = results[0].plot()
    # plot() draws in OpenCV (BGR) channel order; reverse the channels for PIL (RGB).
    return Image.fromarray(annotated_bgr[..., ::-1])


def predict_video(video_path, conf_threshold, iou_threshold):
    """Run detection on every frame of a video and write an annotated copy.

    Args:
        video_path: Filesystem path of the uploaded video, or None/empty if missing.
        conf_threshold: Minimum detection confidence passed to ``predict``.
        iou_threshold: IoU threshold for non-maximum suppression passed to ``predict``.

    Returns:
        Path to a temporary ``.mp4`` file containing the annotated frames. The
        file is created with ``delete=False``, so it outlives this call.

    Raises:
        gr.Error: If no video was provided or OpenCV cannot open it.
    """
    if not video_path:
        raise gr.Error("Please upload a video first.")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise gr.Error("Could not open the uploaded video.")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # also catches NaN
        fps = DEFAULT_FPS

    # Only the path is needed; close our handle so OpenCV can write the file.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name

    out_writer = cv2.VideoWriter(
        out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            results = detection_model.predict(
                frame, conf=conf_threshold, iou=iou_threshold, verbose=False
            )
            # frame came from OpenCV (BGR) and plot() returns BGR, so write it directly.
            out_writer.write(results[0].plot())
    finally:
        # Release both handles even if inference fails part-way through.
        cap.release()
        out_writer.release()
    return out_path


def _threshold_sliders():
    """Create the confidence/IoU slider pair used by each tab."""
    conf_slider = gr.Slider(0.1, 1.0, 0.5, step=0.05, label="Confidence Threshold")
    iou_slider = gr.Slider(0.1, 1.0, 0.6, step=0.05, label="IoU Threshold")
    return conf_slider, iou_slider


# Load model
detection_model = load_model(REPO_ID)

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Battery Key Detection - Image & Video")
    with gr.Tabs():
        with gr.TabItem("Image"):
            with gr.Row():
                img_input = gr.Image(type="pil", label="Upload Image")
                img_output = gr.Image(type="pil", label="Predicted Image")
            conf_slider_img, iou_slider_img = _threshold_sliders()
            run_btn_img = gr.Button("Run Detection on Image")
            run_btn_img.click(
                fn=predict_image,
                inputs=[img_input, conf_slider_img, iou_slider_img],
                outputs=img_output,
            )

        with gr.TabItem("Video"):
            with gr.Row():
                vid_input = gr.Video(label="Upload Video")
                vid_output = gr.Video(label="Predicted Video")
            conf_slider_vid, iou_slider_vid = _threshold_sliders()
            run_btn_vid = gr.Button("Run Detection on Video")
            run_btn_vid.click(
                fn=predict_video,
                inputs=[vid_input, conf_slider_vid, iou_slider_vid],
                outputs=vid_output,
            )

demo.launch()