Cedri committed on
Commit
a2dd7f7
·
verified ·
1 Parent(s): bde702c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -53
app.py CHANGED
@@ -2,83 +2,67 @@ from ultralytics import YOLO
2
  from PIL import Image
3
  import gradio as gr
4
  from huggingface_hub import snapshot_download
 
5
  import os
6
  import cv2
7
- import tempfile
8
- import numpy as np
9
 
10
- # === Load model from Hugging Face ===
11
  def load_model(repo_id):
12
  download_dir = snapshot_download(repo_id)
13
- print("Model downloaded to:", download_dir)
14
  model_path = os.path.join(download_dir, "best.pt")
15
  return YOLO(model_path)
16
 
17
- # === Prediction functions ===
18
- def predict_image(pil_image, conf):
19
- result = detection_model.predict(pil_image, conf=conf, iou=0.6)
20
  img_bgr = result[0].plot()
21
- output = Image.fromarray(result[0].plot())
22
- return output
23
-
24
- def predict_video(video_file, conf):
25
- cap = cv2.VideoCapture(video_file)
26
- if not cap.isOpened():
27
- raise IOError("Cannot open video file")
28
 
 
 
 
29
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
30
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
31
  fps = cap.get(cv2.CAP_PROP_FPS)
32
-
33
- temp_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
34
- out = cv2.VideoWriter(temp_output.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
35
 
36
  while cap.isOpened():
37
  ret, frame = cap.read()
38
  if not ret:
39
  break
40
-
41
- result = detection_model.predict(frame, conf=conf, iou=0.6)
42
  annotated = result[0].plot()
43
- out.write(annotated)
44
 
45
  cap.release()
46
- out.release()
47
- return temp_output.name
48
 
49
- # === Load model ===
50
  REPO_ID = "Cedri/battery_key_yolov8"
51
  detection_model = load_model(REPO_ID)
52
 
53
- # === Gradio UI ===
54
- def image_interface(image, conf_threshold):
55
- return predict_image(image, conf=conf_threshold)
56
-
57
- def video_interface(video, conf_threshold):
58
- return predict_video(video, conf=conf_threshold)
59
-
60
- image_tab = gr.Interface(
61
- fn=image_interface,
62
- inputs=[
63
- gr.Image(label="Upload Image"),
64
- gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.7, label="Confidence Threshold")
65
- ],
66
- outputs=gr.Image(label="Detected Image"),
67
- title="Battery Key Detection (Image)"
68
- )
69
 
70
- video_tab = gr.Interface(
71
- fn=video_interface,
72
- inputs=[
73
- gr.Video(label="Upload Video"),
74
- gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.7, label="Confidence Threshold")
75
- ],
76
- outputs=gr.Video(label="Detected Video"),
77
- title="Battery Key Detection (Video)"
78
- )
79
 
80
- # === Launch with tabs ===
81
- gr.TabbedInterface(
82
- [image_tab, video_tab],
83
- tab_names=["Image", "Video"]
84
- ).launch()
 
2
  from PIL import Image
3
  import gradio as gr
4
  from huggingface_hub import snapshot_download
5
+ import tempfile
6
  import os
7
  import cv2
 
 
8
 
9
# Fetch the model repository from the Hub and build a YOLO detector from it.
def load_model(repo_id):
    """Download *repo_id* from the Hugging Face Hub and return a YOLO model.

    The repository snapshot is expected to contain a ``best.pt`` checkpoint
    at its root.
    """
    snapshot_dir = snapshot_download(repo_id)
    weights_path = os.path.join(snapshot_dir, "best.pt")
    return YOLO(weights_path)
14
 
15
# Run detection on a single image and return the annotated result.
def predict_image(image, conf_threshold, iou_threshold):
    """Annotate *image* with detections and return it as a PIL image.

    ``plot()`` yields a BGR array, so the channel axis is reversed before the
    PIL conversion to get RGB.
    """
    detections = detection_model.predict(image, conf=conf_threshold, iou=iou_threshold)
    annotated_bgr = detections[0].plot()
    annotated_rgb = annotated_bgr[..., ::-1]
    return Image.fromarray(annotated_rgb)
 
 
 
 
 
 
20
 
21
# Process video input
def predict_video(video_path, conf_threshold, iou_threshold):
    """Run detection on every frame of *video_path* and write an annotated copy.

    Parameters:
        video_path: path to the input video file.
        conf_threshold: YOLO confidence threshold passed to ``predict``.
        iou_threshold: YOLO IoU (NMS) threshold passed to ``predict``.

    Returns:
        Path to a temporary ``.mp4`` file containing the annotated video.
        The caller is responsible for deleting it (``delete=False``).

    Raises:
        IOError: if the input video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    # Guard against unreadable input; otherwise the writer below is created
    # with garbage dimensions and silently produces an empty file.
    if not cap.isOpened():
        raise IOError(f"Cannot open video file: {video_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Some containers report 0 (or NaN) FPS; fall back to a sane default so
    # VideoWriter does not fail.
    if not (fps and fps > 0):
        fps = 30.0

    out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    out_writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            result = detection_model.predict(frame, conf=conf_threshold, iou=iou_threshold)
            annotated = result[0].plot()
            out_writer.write(annotated)
    finally:
        # Release both handles even if inference raises mid-video.
        cap.release()
        out_writer.release()
    return out_path
41
 
42
# Load model
# Hugging Face Hub repository hosting the trained weights (contains "best.pt").
REPO_ID = "Cedri/battery_key_yolov8"
# Module-level detector shared by the prediction functions in this file.
# NOTE: this downloads the snapshot at import time (network I/O on startup).
detection_model = load_model(REPO_ID)
45
 
46
# Gradio UI: two tabs (image / video), each with its own threshold sliders.
with gr.Blocks() as demo:
    gr.Markdown("## Battery Key Detection - Image & Video")
    with gr.Tabs():
        # --- Image tab: upload a picture, tune thresholds, run detection ---
        with gr.TabItem("Image"):
            with gr.Row():
                image_in = gr.Image(type="pil", label="Upload Image")
                image_out = gr.Image(type="pil", label="Predicted Image")
            image_conf = gr.Slider(0.1, 1.0, 0.5, step=0.05, label="Confidence Threshold")
            image_iou = gr.Slider(0.1, 1.0, 0.6, step=0.05, label="IoU Threshold")
            image_button = gr.Button("Run Detection on Image")
            image_button.click(
                fn=predict_image,
                inputs=[image_in, image_conf, image_iou],
                outputs=image_out,
            )

        # --- Video tab: same controls, frame-by-frame annotation ---
        with gr.TabItem("Video"):
            with gr.Row():
                video_in = gr.Video(label="Upload Video")
                video_out = gr.Video(label="Predicted Video")
            video_conf = gr.Slider(0.1, 1.0, 0.5, step=0.05, label="Confidence Threshold")
            video_iou = gr.Slider(0.1, 1.0, 0.6, step=0.05, label="IoU Threshold")
            video_button = gr.Button("Run Detection on Video")
            video_button.click(
                fn=predict_video,
                inputs=[video_in, video_conf, video_iou],
                outputs=video_out,
            )

demo.launch()