|
|
from ultralytics import YOLO |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
from huggingface_hub import snapshot_download |
|
|
import tempfile |
|
|
import os |
|
|
import cv2 |
|
|
|
|
|
|
|
|
def load_model(repo_id, filename="best.pt"):
    """Download YOLO weights from a Hugging Face Hub repo and load them.

    Args:
        repo_id: Hub repository id, e.g. ``"owner/name"``.
        filename: Weights file at the repository root. Defaults to
            ``"best.pt"``, which is what this app always used.

    Returns:
        An ultralytics ``YOLO`` model built from the downloaded weights.

    Raises:
        FileNotFoundError: If ``filename`` is not present at the root of the
            downloaded snapshot.
    """
    # Only fetch the weights file instead of the whole repository, which may
    # contain other checkpoints or artifacts this app never uses.
    download_dir = snapshot_download(repo_id, allow_patterns=[filename])
    model_path = os.path.join(download_dir, filename)
    if not os.path.isfile(model_path):
        raise FileNotFoundError(
            f"{filename!r} not found at the root of Hugging Face repo {repo_id!r}"
        )
    return YOLO(model_path)
|
|
|
|
|
|
|
|
def predict_image(image, conf_threshold, iou_threshold):
    """Run detection on a single image and return the annotated result.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user clicked the button without uploading anything.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold passed to the model's NMS step.

    Returns:
        A PIL image with the detections drawn on it.

    Raises:
        gr.Error: If no image was provided (shown to the user by Gradio).
    """
    # ultralytics falls back to its bundled sample assets when source is None,
    # so an empty upload must be rejected here rather than passed through.
    if image is None:
        raise gr.Error("Please upload an image first.")

    result = detection_model.predict(image, conf=conf_threshold, iou=iou_threshold)
    # plot() yields a BGR array; reverse the channel axis to get RGB for PIL.
    annotated_bgr = result[0].plot()
    return Image.fromarray(annotated_bgr[..., ::-1])
|
|
|
|
|
|
|
|
def predict_video(video_path, conf_threshold, iou_threshold):
    """Run detection on every frame of a video and write an annotated copy.

    Args:
        video_path: Path to the uploaded video as provided by ``gr.Video``,
            or ``None`` when nothing was uploaded.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold passed to the model's NMS step.

    Returns:
        Path to a temporary ``.mp4`` file containing the annotated frames.
        The file is not deleted by this function.

    Raises:
        gr.Error: If no video was provided or it cannot be opened.
    """
    if not video_path:
        raise gr.Error("Please upload a video first.")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise gr.Error(f"Could not open video: {video_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Some containers report 0 (or NaN) FPS; VideoWriter cannot use that.
    if not (fps > 0):
        fps = 30.0

    # delete=False keeps the file after the handle closes so Gradio can serve
    # it; the `with` closes our handle before OpenCV reopens the path.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name

    # NOTE(review): mp4v-encoded MP4 may not play inline in every browser.
    out_writer = cv2.VideoWriter(
        out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # verbose=False avoids one log line per frame.
            result = detection_model.predict(
                frame, conf=conf_threshold, iou=iou_threshold, verbose=False
            )
            # plot() returns BGR, which is what VideoWriter expects.
            out_writer.write(result[0].plot())
    finally:
        # Release both handles even if prediction fails mid-video.
        cap.release()
        out_writer.release()

    return out_path
|
|
|
|
|
|
|
|
# Hugging Face Hub repository whose weights file is loaded by load_model.
REPO_ID = "Cedri/battery_key_yolov8"

# Loaded once at import time; predict_image and predict_video both use it.
detection_model = load_model(repo_id=REPO_ID)
|
|
|
|
|
|
|
|
def _threshold_sliders():
    """Create the confidence and IoU sliders shared by both tabs.

    Must be called inside an active ``gr.Blocks`` layout context so the
    sliders are rendered where the call happens.

    Returns:
        A ``(conf_slider, iou_slider)`` tuple.
    """
    conf_slider = gr.Slider(0.1, 1.0, 0.5, step=0.05, label="Confidence Threshold")
    iou_slider = gr.Slider(0.1, 1.0, 0.6, step=0.05, label="IoU Threshold")
    return conf_slider, iou_slider


# Two-tab UI: one tab for single images, one for videos. Each tab has its own
# threshold sliders so the settings are independent.
with gr.Blocks() as demo:
    gr.Markdown("## Battery Key Detection - Image & Video")

    with gr.Tabs():
        with gr.TabItem("Image"):
            with gr.Row():
                img_input = gr.Image(type="pil", label="Upload Image")
                img_output = gr.Image(type="pil", label="Predicted Image")
            conf_slider_img, iou_slider_img = _threshold_sliders()
            run_btn_img = gr.Button("Run Detection on Image")
            run_btn_img.click(
                fn=predict_image,
                inputs=[img_input, conf_slider_img, iou_slider_img],
                outputs=img_output,
            )

        with gr.TabItem("Video"):
            with gr.Row():
                vid_input = gr.Video(label="Upload Video")
                vid_output = gr.Video(label="Predicted Video")
            conf_slider_vid, iou_slider_vid = _threshold_sliders()
            run_btn_vid = gr.Button("Run Detection on Video")
            run_btn_vid.click(
                fn=predict_video,
                inputs=[vid_input, conf_slider_vid, iou_slider_vid],
                outputs=vid_output,
            )


# Only start the server when run as a script, so importing this module
# (e.g. by tooling that looks up `demo`) does not block on launch().
if __name__ == "__main__":
    demo.launch()
|
|
|