JLtan1024 committed on
Commit
b1c18b2
·
verified ·
1 Parent(s): 93e262c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -37
app.py CHANGED
@@ -3,10 +3,20 @@ import numpy as np
3
  from PIL import Image, ImageDraw, ImageFont
4
  from collections import Counter
5
  import time
6
- import tempfile
7
  from ultralytics import YOLO
8
  import cv2
9
- import os
 
 
 
 
 
 
 
 
 
 
10
 
11
  # Constants
12
  COIN_CLASS_ID = 11 # 10sen coin
@@ -114,7 +124,7 @@ def non_max_suppression(detections, iou_threshold):
114
 
115
  return [detections[i] for i in keep_indices]
116
 
117
- class VideoProcessor:
118
  def __init__(self):
119
  self.px_to_mm_ratio = None
120
  self.detected_objects = []
@@ -147,12 +157,12 @@ class VideoProcessor:
147
  if isinstance(frame, np.ndarray):
148
  frame_np = frame
149
  else:
150
- # This handles the case if frame comes from Gradio's webcam which is already numpy
151
  frame_np = np.array(frame)
152
 
153
  results = model(frame_np, conf=self.confidence_threshold)
154
 
155
- if not results:
156
  return frame_np, []
157
 
158
  result = results[0]
@@ -226,6 +236,30 @@ class VideoProcessor:
226
  # Convert back to BGR for OpenCV operations
227
  return cv2.cvtColor(processed_img, cv2.COLOR_RGB2BGR), frame_detected_objects
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  def process_image(input_image, iou_threshold, confidence_threshold, show_detections, show_summary):
230
  if input_image is None:
231
  return None, "Please upload an image first."
@@ -237,7 +271,7 @@ def process_image(input_image, iou_threshold, confidence_threshold, show_detecti
237
  frame = input_image
238
 
239
  # Create a temporary processor for image processing
240
- processor = VideoProcessor()
241
  processor.update_settings(iou_threshold, confidence_threshold, show_detections, show_summary)
242
  processed_frame, _ = processor.process_frame(frame)
243
 
@@ -254,7 +288,7 @@ def process_video(video_path, iou_threshold, confidence_threshold, show_detectio
254
 
255
  try:
256
  # Create a processor for video processing
257
- processor = VideoProcessor()
258
  processor.update_settings(iou_threshold, confidence_threshold, show_detections, show_summary)
259
 
260
  cap = cv2.VideoCapture(video_path)
@@ -288,23 +322,20 @@ def process_video(video_path, iou_threshold, confidence_threshold, show_detectio
288
  except Exception as e:
289
  return [], f"Error processing video: {str(e)}"
290
 
291
- def process_webcam(frame, iou_threshold, confidence_threshold, show_detections, show_summary):
292
- if frame is None:
293
- return None
294
-
295
- # Convert from RGB to BGR for processing
296
- frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
297
-
298
- # Create a temporary processor for webcam processing
299
- processor = VideoProcessor()
300
- processor.update_settings(iou_threshold, confidence_threshold, show_detections, show_summary)
301
-
302
- # Process the frame
303
- processed_frame, _ = processor.process_frame(frame_bgr)
304
-
305
- # Convert back to RGB for Gradio
306
- processed_frame_rgb = cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB)
307
- return processed_frame_rgb
308
 
309
  # Gradio Interface
310
  with gr.Blocks(title="Screw Detection and Measurement") as demo:
@@ -348,28 +379,57 @@ with gr.Blocks(title="Screw Detection and Measurement") as demo:
348
  outputs=[video_output, video_summary]
349
  )
350
 
351
- with gr.Tab("Webcam"):
352
  with gr.Row():
353
- with gr.Column():
354
  webcam_iou = gr.Slider(label="IoU Threshold (NMS)", minimum=0.0, maximum=1.0, value=0.7, step=0.05)
355
  webcam_conf = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.05)
356
  webcam_show_det = gr.Checkbox(label="Show Detections", value=True)
357
  webcam_show_sum = gr.Checkbox(label="Show Summary", value=True)
358
- with gr.Column():
359
- # Use the compatible webcam syntax - this is compatible with older versions of gradio
360
- webcam_input = gr.Image(label="Live Camera")
361
- webcam_output = gr.Image(label="Processed Output")
362
- webcam_button = gr.Button("Process Webcam Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
- # Use click instead of change for webcam processing
365
- webcam_button.click(
366
- fn=process_webcam,
367
- inputs=[webcam_input, webcam_iou, webcam_conf, webcam_show_det, webcam_show_sum],
368
- outputs=webcam_output
 
 
 
 
 
 
 
369
  )
370
 
371
  # Add warning about model loading
372
  if model is None:
373
  gr.Warning("Model could not be loaded. Please ensure 'yolo11-obb12classes.pt' is available.")
374
 
375
- demo.launch()
 
3
  from PIL import Image, ImageDraw, ImageFont
4
  from collections import Counter
5
  import time
6
+ import os
7
  from ultralytics import YOLO
8
  import cv2
9
+ from gradio_client.documentation import document, DocumentedType
10
+
11
+ # Import WebRTC components
12
+ from gradio_webrtc import (
13
+ RTCConfiguration,
14
+ WebRtcStreamerContext,
15
+ WebRtcMode,
16
+ WebRtcStreamer,
17
+ VideoTransformerBase,
18
+ VideoTransformerContext,
19
+ )
20
 
21
  # Constants
22
  COIN_CLASS_ID = 11 # 10sen coin
 
124
 
125
  return [detections[i] for i in keep_indices]
126
 
127
+ class ScrewDetectionProcessor:
128
  def __init__(self):
129
  self.px_to_mm_ratio = None
130
  self.detected_objects = []
 
157
  if isinstance(frame, np.ndarray):
158
  frame_np = frame
159
  else:
160
+ # This handles the case if frame comes from other sources
161
  frame_np = np.array(frame)
162
 
163
  results = model(frame_np, conf=self.confidence_threshold)
164
 
165
+ if not results or len(results) == 0:
166
  return frame_np, []
167
 
168
  result = results[0]
 
236
  # Convert back to BGR for OpenCV operations
237
  return cv2.cvtColor(processed_img, cv2.COLOR_RGB2BGR), frame_detected_objects
238
 
239
# WebRTC video transformer: runs screw detection on every incoming frame.
class ScrewDetectionTransformer(VideoTransformerBase):
    """Per-stream transformer that feeds live video through a ScrewDetectionProcessor."""

    def __init__(self):
        # Each WebRTC stream owns its own processor and therefore its own settings.
        self.processor = ScrewDetectionProcessor()
        self.summary_text = "No detections yet."

    def update_settings(self, iou_threshold, confidence_threshold, show_detections, show_summary):
        """Forward the UI slider/checkbox values to the underlying processor."""
        self.processor.update_settings(
            iou_threshold=iou_threshold,
            confidence_threshold=confidence_threshold,
            show_detections=show_detections,
            show_summary=show_summary,
        )

    def get_summary(self):
        """Return the processor's current detection summary text."""
        return self.processor.get_summary()

    def transform(self, frame):
        """Invoked once per video frame; returns the annotated BGR image."""
        bgr_frame = frame.to_ndarray(format="bgr24")
        annotated, _ = self.processor.process_frame(bgr_frame)
        # Cache the summary so the UI can poll it between frames.
        self.summary_text = self.processor.get_summary()
        return annotated
+
263
  def process_image(input_image, iou_threshold, confidence_threshold, show_detections, show_summary):
264
  if input_image is None:
265
  return None, "Please upload an image first."
 
271
  frame = input_image
272
 
273
  # Create a temporary processor for image processing
274
+ processor = ScrewDetectionProcessor()
275
  processor.update_settings(iou_threshold, confidence_threshold, show_detections, show_summary)
276
  processed_frame, _ = processor.process_frame(frame)
277
 
 
288
 
289
  try:
290
  # Create a processor for video processing
291
+ processor = ScrewDetectionProcessor()
292
  processor.update_settings(iou_threshold, confidence_threshold, show_detections, show_summary)
293
 
294
  cap = cv2.VideoCapture(video_path)
 
322
  except Exception as e:
323
  return [], f"Error processing video: {str(e)}"
324
 
325
def update_webrtc_settings(iou_threshold, confidence_threshold, show_detections, show_summary, webrtc_ctx):
    """Push the current UI settings into the live WebRTC transformer.

    Returns the status string "Settings updated" when a transformer is
    attached to the context; returns None when there is no active stream
    (matching the original implicit-None fall-through).
    """
    # Guard clause: nothing to update when the stream is not running.
    if not (webrtc_ctx and webrtc_ctx.video_transformer):
        return None
    webrtc_ctx.video_transformer.update_settings(
        iou_threshold=iou_threshold,
        confidence_threshold=confidence_threshold,
        show_detections=show_detections,
        show_summary=show_summary,
    )
    return "Settings updated"
335
def get_webrtc_summary(webrtc_ctx):
    """Return the live transformer's detection summary, or a placeholder when inactive."""
    transformer = webrtc_ctx.video_transformer if webrtc_ctx else None
    if not transformer:
        return "WebRTC not active"
    return transformer.get_summary()
 
 
 
339
 
340
  # Gradio Interface
341
  with gr.Blocks(title="Screw Detection and Measurement") as demo:
 
379
  outputs=[video_output, video_summary]
380
  )
381
 
382
+ with gr.Tab("WebRTC Webcam"):
383
  with gr.Row():
384
+ with gr.Column(scale=1):
385
  webcam_iou = gr.Slider(label="IoU Threshold (NMS)", minimum=0.0, maximum=1.0, value=0.7, step=0.05)
386
  webcam_conf = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.05)
387
  webcam_show_det = gr.Checkbox(label="Show Detections", value=True)
388
  webcam_show_sum = gr.Checkbox(label="Show Summary", value=True)
389
+
390
+ # Create a settings update button
391
+ update_settings = gr.Button("Update Settings")
392
+
393
+ # Summary textbox
394
+ webcam_summary = gr.Textbox(label="Detection Summary", interactive=False)
395
+
396
+ # Button to get summary
397
+ get_summary = gr.Button("Get Detection Summary")
398
+
399
+ with gr.Column(scale=2):
400
+ # Configure WebRTC with STUN servers
401
+ rtc_config = RTCConfiguration(
402
+ {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
403
+ )
404
+
405
+ # Create the WebRTC component with our transformer
406
+ webrtc_ctx = gr.State(None)
407
+
408
+ # Use WebRtcStreamer with our transformer
409
+ webrtc = WebRtcStreamer(
410
+ key="screw-detection",
411
+ mode=WebRtcMode.SENDRECV,
412
+ rtc_configuration=rtc_config,
413
+ video_transformer_factory=ScrewDetectionTransformer,
414
+ async_transform=True,
415
+ )
416
 
417
+ # Connect the update settings button
418
+ update_settings.click(
419
+ update_webrtc_settings,
420
+ inputs=[webcam_iou, webcam_conf, webcam_show_det, webcam_show_sum, webrtc_ctx],
421
+ outputs=gr.Textbox(value="Settings updated", visible=False)
422
+ )
423
+
424
+ # Connect the get summary button
425
+ get_summary.click(
426
+ get_webrtc_summary,
427
+ inputs=[webrtc_ctx],
428
+ outputs=webcam_summary
429
  )
430
 
431
  # Add warning about model loading
432
  if model is None:
433
  gr.Warning("Model could not be loaded. Please ensure 'yolo11-obb12classes.pt' is available.")
434
 
435
+ demo.launch()