JLtan1024 committed on
Commit
8dfa70e
·
verified ·
1 Parent(s): 03082e3

gradio-webrtc

Browse files
Files changed (1) hide show
  1. app.py +127 -125
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import numpy as np
3
  from PIL import Image, ImageDraw, ImageFont
4
  from collections import Counter
@@ -6,6 +7,8 @@ import time
6
  import tempfile
7
  from ultralytics import YOLO
8
  import cv2
 
 
9
 
10
  # Constants
11
  COIN_CLASS_ID = 11 # 10sen coin
@@ -105,99 +108,133 @@ def non_max_suppression(detections, iou_threshold):
105
 
106
  return [detections[i] for i in keep_indices]
107
 
108
- def process_frame(frame, iou_threshold, confidence_threshold, show_detections, px_to_mm_ratio=None):
109
- """Process a single frame and return annotated image and detection data"""
110
- results = model(frame, conf=confidence_threshold)
 
 
 
 
 
 
111
 
112
- if not results:
113
- return frame, [], px_to_mm_ratio
 
 
 
 
114
 
115
- result = results[0]
116
- filtered_detections = non_max_suppression(result.obb, iou_threshold)
117
-
118
- pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
119
- draw = ImageDraw.Draw(pil_image)
 
 
 
 
 
120
 
121
- try:
122
- font = ImageFont.truetype("arial.ttf", LABEL_FONT_SIZE)
123
- except:
124
- font = ImageFont.load_default()
125
- if hasattr(font, 'size'):
126
- font.size = LABEL_FONT_SIZE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- detected_objects = []
129
- current_px_to_mm_ratio = px_to_mm_ratio
130
-
131
- # Find coin for scaling
132
- if current_px_to_mm_ratio is None:
 
 
 
 
 
 
 
 
 
 
133
  for detection in filtered_detections:
134
- if len(detection.cls) > 0 and int(detection.cls[0]) == COIN_CLASS_ID and len(detection.xywhr) > 0:
135
- coin_xywhr = detection.xywhr[0]
136
- width_px = coin_xywhr[2]
137
- height_px = coin_xywhr[3]
138
- avg_px_diameter = (width_px + height_px) / 2
139
- if avg_px_diameter > 0:
140
- current_px_to_mm_ratio = COIN_DIAMETER_MM / avg_px_diameter
141
- break
142
 
143
- # Draw detections
144
- for detection in filtered_detections:
145
- if len(detection.cls) > 0 and len(detection.xywhr) > 0 and len(detection.xyxy) > 0:
146
- class_id = int(detection.cls[0])
147
- confidence = detection.conf[0]
148
- x1, y1, x2, y2 = map(int, detection.xyxy[0])
149
- class_name = CLASS_NAMES.get(class_id, f"Class {int(class_id)}")
150
- color = CATEGORY_COLORS.get(class_name, (0, 255, 0))
151
 
152
- label_text = f"{class_name}"
153
- if class_id != COIN_CLASS_ID:
154
- detected_objects.append(class_name)
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
- if class_id == COIN_CLASS_ID and current_px_to_mm_ratio:
157
- diameter_px = (x2 - x1 + y2 - y1) / 2
158
- diameter_mm = diameter_px * current_px_to_mm_ratio
159
- label_text += f", Dia: {diameter_mm:.2f}mm"
160
- elif class_id != COIN_CLASS_ID and current_px_to_mm_ratio:
161
- xywhr = detection.xywhr[0]
162
- width_px = xywhr[2]
163
- height_px = xywhr[3]
164
- length_px = max(width_px, height_px)
165
- length_mm = length_px * current_px_to_mm_ratio
166
- label_text += f", Length: {length_mm:.2f}mm"
167
- elif class_id != COIN_CLASS_ID:
168
- label_text += ", Length: N/A (No Coin)"
169
- elif class_id == COIN_CLASS_ID:
170
- label_text += ", Dia: N/A (No Ratio)"
171
 
172
- if show_detections:
173
- draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=BORDER_WIDTH)
174
- text_width, text_height = get_text_size(draw, label_text, font)
175
- draw.rectangle([(x1, y1 - text_height - 5), (x1 + text_width + 5, y1)], fill=color)
176
- draw.text((x1 + 2, y1 - text_height - 3), label_text, fill=(255, 255, 255), font=font)
177
 
178
- return np.array(pil_image), detected_objects, current_px_to_mm_ratio
 
 
179
 
180
  def process_image(input_image, iou_threshold, confidence_threshold, show_detections, show_summary):
181
  frame = np.array(input_image)
182
- processed_frame, detected_objects, _ = process_frame(frame, iou_threshold, confidence_threshold, show_detections)
 
 
 
 
183
 
184
  output_image = Image.fromarray(processed_frame)
185
 
186
- summary = ""
187
- if show_summary and detected_objects:
188
- screw_counts = Counter(detected_objects)
189
- summary = "Detection Summary:\n"
190
- for name, count in screw_counts.items():
191
- summary += f"- {name}: {count}\n"
192
- elif show_summary:
193
- summary = "No screws or nuts detected."
194
 
195
  return output_image, summary
196
 
197
  def process_video(video_path, iou_threshold, confidence_threshold, show_detections, show_summary):
198
  cap = cv2.VideoCapture(video_path)
199
- px_to_mm_ratio = None
200
- all_detected_objects = []
 
 
201
 
202
  frames = []
203
  while cap.isOpened():
@@ -205,60 +242,15 @@ def process_video(video_path, iou_threshold, confidence_threshold, show_detectio
205
  if not ret:
206
  break
207
 
208
- processed_frame, detected_objects, px_to_mm_ratio = process_frame(
209
- frame, iou_threshold, confidence_threshold, show_detections, px_to_mm_ratio
210
- )
211
-
212
- if detected_objects:
213
- all_detected_objects.extend(detected_objects)
214
-
215
  frames.append(processed_frame)
216
 
217
  cap.release()
218
 
219
- summary = ""
220
- if show_summary and all_detected_objects:
221
- screw_counts = Counter(all_detected_objects)
222
- summary = "Detection Summary:\n"
223
- for name, count in screw_counts.items():
224
- summary += f"- {name}: {count}\n"
225
- elif show_summary:
226
- summary = "No screws or nuts detected."
227
 
228
  return frames, summary
229
 
230
- def webcam_capture(iou_threshold, confidence_threshold, show_detections, show_summary):
231
- cap = cv2.VideoCapture(0)
232
- px_to_mm_ratio = None
233
- all_detected_objects = []
234
-
235
- while True:
236
- ret, frame = cap.read()
237
- if not ret:
238
- break
239
-
240
- processed_frame, detected_objects, px_to_mm_ratio = process_frame(
241
- frame, iou_threshold, confidence_threshold, show_detections, px_to_mm_ratio
242
- )
243
-
244
- if detected_objects:
245
- all_detected_objects.extend(detected_objects)
246
-
247
- yield cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB)
248
-
249
- cap.release()
250
-
251
- summary = ""
252
- if show_summary and all_detected_objects:
253
- screw_counts = Counter(all_detected_objects)
254
- summary = "Detection Summary:\n"
255
- for name, count in screw_counts.items():
256
- summary += f"- {name}: {count}\n"
257
- elif show_summary:
258
- summary = "No screws or nuts detected."
259
-
260
- yield None, summary
261
-
262
  # Gradio Interface
263
  with gr.Blocks(title="Screw Detection and Measurement") as demo:
264
  gr.Markdown("# 🔍 Screw Detection and Measurement (YOLOv11 OBB)")
@@ -308,15 +300,25 @@ with gr.Blocks(title="Screw Detection and Measurement") as demo:
308
  webcam_conf = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.05)
309
  webcam_show_det = gr.Checkbox(label="Show Detections", value=True)
310
  webcam_show_sum = gr.Checkbox(label="Show Summary", value=True)
311
- webcam_button = gr.Button("Start Webcam")
312
  with gr.Column():
313
- webcam_output = gr.Image(label="Live Detection", streaming=True)
 
 
 
 
 
314
  webcam_summary = gr.Textbox(label="Summary", interactive=False)
 
 
 
 
 
 
315
 
316
- webcam_button.click(
317
- webcam_capture,
318
- inputs=[webcam_iou, webcam_conf, webcam_show_det, webcam_show_sum],
319
- outputs=[webcam_output, webcam_summary]
320
  )
321
 
322
  demo.launch()
 
1
  import gradio as gr
2
+ from gradio_webrtc import WebRTC
3
  import numpy as np
4
  from PIL import Image, ImageDraw, ImageFont
5
  from collections import Counter
 
7
  import tempfile
8
  from ultralytics import YOLO
9
  import cv2
10
+ import av
11
+ import threading
12
 
13
  # Constants
14
  COIN_CLASS_ID = 11 # 10sen coin
 
108
 
109
  return [detections[i] for i in keep_indices]
110
 
111
class VideoProcessor:
    """Stateful per-stream detector for the WebRTC webcam feed.

    Runs the YOLO OBB model on each incoming frame, draws detections,
    and accumulates per-class counts for the summary box. The pixel-to-mm
    ratio is calibrated once from the first detected 10sen coin and then
    reused for every later frame of the same stream.
    """

    def __init__(self):
        self.px_to_mm_ratio = None    # set once a reference coin is seen
        self.detected_objects = []    # running list of non-coin class names
        self.lock = threading.Lock()  # guards settings and detected_objects
        self.show_detections = True
        self.show_summary = True
        self.iou_threshold = 0.7
        self.confidence_threshold = 0.5

    def update_settings(self, iou_threshold, confidence_threshold, show_detections, show_summary):
        """Atomically replace the detection settings used by later frames."""
        with self.lock:
            self.iou_threshold = iou_threshold
            self.confidence_threshold = confidence_threshold
            self.show_detections = show_detections
            self.show_summary = show_summary

    def get_summary(self):
        """Return a human-readable count of everything detected so far."""
        with self.lock:
            if not self.show_summary or not self.detected_objects:
                return "No screws or nuts detected yet."

            screw_counts = Counter(self.detected_objects)
            summary_text = "Detection Summary:\n"
            for name, count in screw_counts.items():
                summary_text += f"- {name}: {count}\n"
            return summary_text

    def process_frame(self, frame):
        """Run detection on one av.VideoFrame and return an annotated RGB ndarray.

        Always returns a single RGB numpy array — including on the
        no-results path (the original returned a `(frame, [])` tuple there,
        which crashed every caller expecting a single array).
        """
        frame = frame.to_ndarray(format="bgr24")

        # Snapshot settings once so a concurrent update_settings() call
        # cannot change thresholds halfway through a frame.
        with self.lock:
            conf_thr = self.confidence_threshold
            iou_thr = self.iou_threshold
            show_detections = self.show_detections

        results = model(frame, conf=conf_thr)
        if not results:
            # Keep the RGB-array contract on the empty path too.
            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        result = results[0]
        filtered_detections = non_max_suppression(result.obb, iou_thr)

        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)

        try:
            font = ImageFont.truetype("arial.ttf", LABEL_FONT_SIZE)
        except OSError:
            # NOTE(review): the original assigned font.size on the bitmap
            # fallback, which does not resize it; the default font is used
            # as-is here. Pillow >= 10.1 supports load_default(size=...).
            font = ImageFont.load_default()

        frame_detected_objects = []

        # Calibrate the px-to-mm ratio from the first coin ever seen.
        if self.px_to_mm_ratio is None:
            for detection in filtered_detections:
                if len(detection.cls) > 0 and int(detection.cls[0]) == COIN_CLASS_ID and len(detection.xywhr) > 0:
                    coin_xywhr = detection.xywhr[0]
                    avg_px_diameter = (coin_xywhr[2] + coin_xywhr[3]) / 2
                    if avg_px_diameter > 0:
                        self.px_to_mm_ratio = COIN_DIAMETER_MM / avg_px_diameter
                        break

        # Draw detections and build measurement labels.
        for detection in filtered_detections:
            if len(detection.cls) > 0 and len(detection.xywhr) > 0 and len(detection.xyxy) > 0:
                class_id = int(detection.cls[0])
                x1, y1, x2, y2 = map(int, detection.xyxy[0])
                class_name = CLASS_NAMES.get(class_id, f"Class {int(class_id)}")
                color = CATEGORY_COLORS.get(class_name, (0, 255, 0))

                label_text = f"{class_name}"
                if class_id != COIN_CLASS_ID:
                    frame_detected_objects.append(class_name)

                if class_id == COIN_CLASS_ID and self.px_to_mm_ratio:
                    # Mean of box width/height approximates the coin diameter.
                    diameter_px = (x2 - x1 + y2 - y1) / 2
                    diameter_mm = diameter_px * self.px_to_mm_ratio
                    label_text += f", Dia: {diameter_mm:.2f}mm"
                elif class_id != COIN_CLASS_ID and self.px_to_mm_ratio:
                    # Longest OBB side is the screw/nut length in pixels.
                    xywhr = detection.xywhr[0]
                    length_px = max(xywhr[2], xywhr[3])
                    length_mm = length_px * self.px_to_mm_ratio
                    label_text += f", Length: {length_mm:.2f}mm"
                elif class_id != COIN_CLASS_ID:
                    label_text += ", Length: N/A (No Coin)"
                elif class_id == COIN_CLASS_ID:
                    label_text += ", Dia: N/A (No Ratio)"

                if show_detections:
                    draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=BORDER_WIDTH)
                    text_width, text_height = get_text_size(draw, label_text, font)
                    draw.rectangle([(x1, y1 - text_height - 5), (x1 + text_width + 5, y1)], fill=color)
                    draw.text((x1 + 2, y1 - text_height - 3), label_text, fill=(255, 255, 255), font=font)

        with self.lock:
            self.detected_objects.extend(frame_detected_objects)

        return np.array(pil_image)

    def recv(self, frame):
        """WebRTC callback: annotate the frame and hand it back as bgr24."""
        processed_rgb = self.process_frame(frame)
        # process_frame returns RGB; convert back so the declared bgr24
        # format matches the actual channel order (the original labelled
        # the RGB buffer bgr24, swapping red and blue in the stream).
        return av.VideoFrame.from_ndarray(
            cv2.cvtColor(processed_rgb, cv2.COLOR_RGB2BGR), format="bgr24"
        )
217
 
218
def process_image(input_image, iou_threshold, confidence_threshold, show_detections, show_summary):
    """Detect and measure screws/nuts in a single still image.

    Args:
        input_image: PIL image from the Gradio image component (RGB).
        iou_threshold: NMS IoU threshold passed to the processor.
        confidence_threshold: YOLO confidence cutoff.
        show_detections: whether boxes/labels are drawn on the output.
        show_summary: whether a textual count summary is produced.

    Returns:
        (annotated PIL image, summary text) tuple.
    """
    # Gradio hands over an RGB PIL image, but VideoProcessor.process_frame
    # works in BGR (OpenCV channel order). The original wrapped the RGB
    # buffer as bgr24 unconverted, which swapped red/blue in the output.
    # .convert("RGB") also guards against RGBA/paletted uploads.
    rgb = np.array(input_image.convert("RGB"))
    frame = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    # Fresh processor so each image is calibrated and counted independently.
    processor = VideoProcessor()
    processor.update_settings(iou_threshold, confidence_threshold, show_detections, show_summary)
    processed_frame = processor.process_frame(av.VideoFrame.from_ndarray(frame, format="bgr24"))

    output_image = Image.fromarray(processed_frame)
    summary = processor.get_summary()

    return output_image, summary
231
 
232
def process_video(video_path, iou_threshold, confidence_threshold, show_detections, show_summary):
    """Run detection over every frame of a video file.

    Args:
        video_path: path to a video file readable by OpenCV.
        iou_threshold / confidence_threshold / show_detections / show_summary:
            detection settings forwarded to the processor.

    Returns:
        (list of annotated RGB frame arrays, summary text) tuple.
    """
    cap = cv2.VideoCapture(video_path)

    # One processor for the whole clip so the coin calibration found in an
    # early frame carries over, and detection counts accumulate across frames.
    processor = VideoProcessor()
    processor.update_settings(iou_threshold, confidence_threshold, show_detections, show_summary)

    frames = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(
                processor.process_frame(av.VideoFrame.from_ndarray(frame, format="bgr24"))
            )
    finally:
        # Release the capture even if decoding or inference raises mid-stream
        # (the original leaked the handle on any exception in the loop).
        cap.release()

    return frames, processor.get_summary()
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  # Gradio Interface
255
  with gr.Blocks(title="Screw Detection and Measurement") as demo:
256
  gr.Markdown("# 🔍 Screw Detection and Measurement (YOLOv11 OBB)")
 
300
  webcam_conf = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.05)
301
  webcam_show_det = gr.Checkbox(label="Show Detections", value=True)
302
  webcam_show_sum = gr.Checkbox(label="Show Summary", value=True)
303
+ settings_button = gr.Button("Update Settings")
304
  with gr.Column():
305
+ webrtc_ctx = WebRTC(
306
+ mode="sendonly",
307
+ audio=False,
308
+ video_processor_factory=VideoProcessor,
309
+ key="webcam-detection"
310
+ )
311
  webcam_summary = gr.Textbox(label="Summary", interactive=False)
312
+ refresh_button = gr.Button("Refresh Summary")
313
+
314
+ settings_button.click(
315
+ fn=lambda iou, conf, det, summ: webrtc_ctx.video_processor.update_settings(iou, conf, det, summ),
316
+ inputs=[webcam_iou, webcam_conf, webcam_show_det, webcam_show_sum]
317
+ )
318
 
319
+ refresh_button.click(
320
+ fn=lambda: webrtc_ctx.video_processor.get_summary(),
321
+ outputs=webcam_summary
 
322
  )
323
 
324
  demo.launch()