Nadun102 commited on
Commit
96a9ddb
·
verified ·
1 Parent(s): 7bae65e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py CHANGED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from inference_sdk import InferenceHTTPClient
3
+ import cv2
4
+ import numpy as np
5
+ import tempfile
6
+ import os
7
+
8
+ # Initialize Roboflow client
9
+ CLIENT = InferenceHTTPClient(
10
+ api_url="https://serverless.roboflow.com",
11
+ api_key="DIAhXQf6AUsyM1PRfdFa"
12
+ )
13
+
14
+ MODEL_ID = "garbage-detection-pbcjq/7"
15
+
16
+ # ----------------------------
17
+ # Prediction Functions
18
+ # ----------------------------
19
+
20
+ def predict_image(image):
21
+ """
22
+ Accepts a PIL image or NumPy array, sends to Roboflow, returns image with bounding boxes.
23
+ """
24
+ # Save image temporarily
25
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
26
+ cv2.imwrite(temp_file.name, cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
27
+
28
+ # Make prediction
29
+ result = CLIENT.infer(temp_file.name, model_id=MODEL_ID)
30
+
31
+ # Draw bounding boxes
32
+ img = np.array(image).copy()
33
+ for pred in result.get("predictions", []):
34
+ x1, y1, x2, y2 = pred["bbox"]["x"], pred["bbox"]["y"], pred["bbox"]["width"], pred["bbox"]["height"]
35
+ x2 += x1
36
+ y2 += y1
37
+ label = f"{pred['class']} {pred['confidence']:.2f}"
38
+ cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0,255,0), 2)
39
+ cv2.putText(img, label, (int(x1), int(y1)-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2)
40
+
41
+ os.unlink(temp_file.name)
42
+ return img
43
+
44
+ def predict_video(video_file):
45
+ """
46
+ Accepts video path, returns video path with bounding boxes on each frame.
47
+ """
48
+ temp_output = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
49
+ cap = cv2.VideoCapture(video_file)
50
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
51
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
52
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
53
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
54
+
55
+ out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
56
+
57
+ while True:
58
+ ret, frame = cap.read()
59
+ if not ret:
60
+ break
61
+ # Save frame temporarily for prediction
62
+ temp_frame_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg").name
63
+ cv2.imwrite(temp_frame_file, frame)
64
+ result = CLIENT.infer(temp_frame_file, model_id=MODEL_ID)
65
+ os.unlink(temp_frame_file)
66
+
67
+ # Draw predictions
68
+ for pred in result.get("predictions", []):
69
+ x1, y1, w, h = pred["bbox"]["x"], pred["bbox"]["y"], pred["bbox"]["width"], pred["bbox"]["height"]
70
+ x2, y2 = x1 + w, y1 + h
71
+ label = f"{pred['class']} {pred['confidence']:.2f}"
72
+ cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0,255,0), 2)
73
+ cv2.putText(frame, label, (int(x1), int(y1)-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2)
74
+
75
+ out.write(frame)
76
+
77
+ cap.release()
78
+ out.release()
79
+ return temp_output
80
+
81
+ # ----------------------------
82
+ # Gradio Interface
83
+ # ----------------------------
84
+
85
+ with gr.Blocks() as demo:
86
+ gr.Markdown("## 🗑 Garbage Detection App (Image & Video)")
87
+ gr.Markdown("Upload an image or video to detect objects using Roboflow.")
88
+
89
+ with gr.Tabs():
90
+ with gr.Tab("Image"):
91
+ image_input = gr.Image(type="pil")
92
+ image_output = gr.Image()
93
+ image_button = gr.Button("Predict Image")
94
+ image_button.click(predict_image, inputs=image_input, outputs=image_output)
95
+
96
+ with gr.Tab("Video"):
97
+ video_input = gr.Video()
98
+ video_output = gr.Video()
99
+ video_button = gr.Button("Predict Video")
100
+ video_button.click(predict_video, inputs=video_input, outputs=video_output)
101
+
102
+ demo.launch()