File size: 2,700 Bytes
0a61c75
 
 
 
a2dd7f7
0a61c75
69a61f5
0a61c75
a2dd7f7
0a61c75
 
69a61f5
966b22f
 
a2dd7f7
 
 
966b22f
a2dd7f7
bf51705
a2dd7f7
 
 
966b22f
 
 
a2dd7f7
 
966b22f
 
 
 
 
a2dd7f7
966b22f
a2dd7f7
bf51705
966b22f
a2dd7f7
 
0a61c75
a2dd7f7
669867a
 
 
a2dd7f7
 
 
 
 
 
 
 
 
 
 
 
966b22f
a2dd7f7
 
 
 
 
 
 
 
966b22f
a2dd7f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from ultralytics import YOLO
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
import tempfile
import os
import cv2

# Load the YOLO model from Hugging Face
def load_model(repo_id):
    """Fetch a model snapshot from the Hugging Face Hub and load its weights.

    Args:
        repo_id: Hub repository identifier, e.g. "user/model-name".

    Returns:
        A YOLO model initialised from the repository's ``best.pt`` file.
    """
    snapshot_dir = snapshot_download(repo_id)
    weights_path = os.path.join(snapshot_dir, "best.pt")
    return YOLO(weights_path)

# Process image input
def predict_image(image, conf_threshold, iou_threshold):
    """Run detection on a single image and return the annotated result.

    Args:
        image: Input image (PIL) handed to the model.
        conf_threshold: Minimum detection confidence.
        iou_threshold: IoU threshold used for NMS.

    Returns:
        A PIL Image with detection boxes drawn (converted BGR -> RGB).
    """
    results = detection_model.predict(image, conf=conf_threshold, iou=iou_threshold)
    annotated_bgr = results[0].plot()
    # plot() yields a BGR array; reverse the channel axis to get RGB for PIL.
    annotated_rgb = annotated_bgr[..., ::-1]
    return Image.fromarray(annotated_rgb)

# Process video input
def predict_video(video_path, conf_threshold, iou_threshold):
    """Run detection frame-by-frame on a video and write an annotated copy.

    Args:
        video_path: Path to the input video file.
        conf_threshold: Minimum detection confidence.
        iou_threshold: IoU threshold used for NMS.

    Returns:
        Path to a temporary ``.mp4`` file with detections drawn on each frame.
    """
    cap = cv2.VideoCapture(video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Some containers report 0 (or NaN) FPS; fall back to a sane default so
    # VideoWriter does not produce an unplayable file.
    if not fps or fps != fps or fps <= 0:
        fps = 30.0
    out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    out_writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            result = detection_model.predict(frame, conf=conf_threshold, iou=iou_threshold)
            annotated = result[0].plot()
            out_writer.write(annotated)
    finally:
        # Release handles even if inference raises mid-stream, so the output
        # container is finalized and the capture is not leaked.
        cap.release()
        out_writer.release()
    return out_path

# Load model
# Module-level singleton: downloaded once at startup and shared by both the
# image and video prediction handlers above.
REPO_ID = "Cedri/battery_key_yolov8"
detection_model = load_model(REPO_ID)

# Gradio UI: one tab per media type, each with its own threshold controls.
with gr.Blocks() as demo:
    gr.Markdown("## Battery Key Detection - Image & Video")
    with gr.Tabs():
        with gr.TabItem("Image"):
            with gr.Row():
                image_in = gr.Image(type="pil", label="Upload Image")
                image_out = gr.Image(type="pil", label="Predicted Image")
            image_conf = gr.Slider(0.1, 1.0, 0.5, step=0.05, label="Confidence Threshold")
            image_iou = gr.Slider(0.1, 1.0, 0.6, step=0.05, label="IoU Threshold")
            image_btn = gr.Button("Run Detection on Image")
            image_btn.click(
                fn=predict_image,
                inputs=[image_in, image_conf, image_iou],
                outputs=image_out,
            )

        with gr.TabItem("Video"):
            with gr.Row():
                video_in = gr.Video(label="Upload Video")
                video_out = gr.Video(label="Predicted Video")
            video_conf = gr.Slider(0.1, 1.0, 0.5, step=0.05, label="Confidence Threshold")
            video_iou = gr.Slider(0.1, 1.0, 0.6, step=0.05, label="IoU Threshold")
            video_btn = gr.Button("Run Detection on Video")
            video_btn.click(
                fn=predict_video,
                inputs=[video_in, video_conf, video_iou],
                outputs=video_out,
            )

demo.launch()