Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,10 +3,20 @@ import numpy as np
|
|
| 3 |
from PIL import Image, ImageDraw, ImageFont
|
| 4 |
from collections import Counter
|
| 5 |
import time
|
| 6 |
-
import
|
| 7 |
from ultralytics import YOLO
|
| 8 |
import cv2
|
| 9 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Constants
|
| 12 |
COIN_CLASS_ID = 11 # 10sen coin
|
|
@@ -114,7 +124,7 @@ def non_max_suppression(detections, iou_threshold):
|
|
| 114 |
|
| 115 |
return [detections[i] for i in keep_indices]
|
| 116 |
|
| 117 |
-
class
|
| 118 |
def __init__(self):
|
| 119 |
self.px_to_mm_ratio = None
|
| 120 |
self.detected_objects = []
|
|
@@ -147,12 +157,12 @@ class VideoProcessor:
|
|
| 147 |
if isinstance(frame, np.ndarray):
|
| 148 |
frame_np = frame
|
| 149 |
else:
|
| 150 |
-
# This handles the case if frame comes from
|
| 151 |
frame_np = np.array(frame)
|
| 152 |
|
| 153 |
results = model(frame_np, conf=self.confidence_threshold)
|
| 154 |
|
| 155 |
-
if not results:
|
| 156 |
return frame_np, []
|
| 157 |
|
| 158 |
result = results[0]
|
|
@@ -226,6 +236,30 @@ class VideoProcessor:
|
|
| 226 |
# Convert back to BGR for OpenCV operations
|
| 227 |
return cv2.cvtColor(processed_img, cv2.COLOR_RGB2BGR), frame_detected_objects
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
def process_image(input_image, iou_threshold, confidence_threshold, show_detections, show_summary):
|
| 230 |
if input_image is None:
|
| 231 |
return None, "Please upload an image first."
|
|
@@ -237,7 +271,7 @@ def process_image(input_image, iou_threshold, confidence_threshold, show_detecti
|
|
| 237 |
frame = input_image
|
| 238 |
|
| 239 |
# Create a temporary processor for image processing
|
| 240 |
-
processor =
|
| 241 |
processor.update_settings(iou_threshold, confidence_threshold, show_detections, show_summary)
|
| 242 |
processed_frame, _ = processor.process_frame(frame)
|
| 243 |
|
|
@@ -254,7 +288,7 @@ def process_video(video_path, iou_threshold, confidence_threshold, show_detectio
|
|
| 254 |
|
| 255 |
try:
|
| 256 |
# Create a processor for video processing
|
| 257 |
-
processor =
|
| 258 |
processor.update_settings(iou_threshold, confidence_threshold, show_detections, show_summary)
|
| 259 |
|
| 260 |
cap = cv2.VideoCapture(video_path)
|
|
@@ -288,23 +322,20 @@ def process_video(video_path, iou_threshold, confidence_threshold, show_detectio
|
|
| 288 |
except Exception as e:
|
| 289 |
return [], f"Error processing video: {str(e)}"
|
| 290 |
|
| 291 |
-
def
|
| 292 |
-
if
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
# Convert back to RGB for Gradio
|
| 306 |
-
processed_frame_rgb = cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB)
|
| 307 |
-
return processed_frame_rgb
|
| 308 |
|
| 309 |
# Gradio Interface
|
| 310 |
with gr.Blocks(title="Screw Detection and Measurement") as demo:
|
|
@@ -348,28 +379,57 @@ with gr.Blocks(title="Screw Detection and Measurement") as demo:
|
|
| 348 |
outputs=[video_output, video_summary]
|
| 349 |
)
|
| 350 |
|
| 351 |
-
with gr.Tab("Webcam"):
|
| 352 |
with gr.Row():
|
| 353 |
-
with gr.Column():
|
| 354 |
webcam_iou = gr.Slider(label="IoU Threshold (NMS)", minimum=0.0, maximum=1.0, value=0.7, step=0.05)
|
| 355 |
webcam_conf = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.05)
|
| 356 |
webcam_show_det = gr.Checkbox(label="Show Detections", value=True)
|
| 357 |
webcam_show_sum = gr.Checkbox(label="Show Summary", value=True)
|
| 358 |
-
|
| 359 |
-
#
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
|
| 364 |
-
#
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
inputs=[
|
| 368 |
-
outputs=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
)
|
| 370 |
|
| 371 |
# Add warning about model loading
|
| 372 |
if model is None:
|
| 373 |
gr.Warning("Model could not be loaded. Please ensure 'yolo11-obb12classes.pt' is available.")
|
| 374 |
|
| 375 |
-
demo.launch()
|
|
|
|
| 3 |
from PIL import Image, ImageDraw, ImageFont
|
| 4 |
from collections import Counter
|
| 5 |
import time
|
| 6 |
+
import os
|
| 7 |
from ultralytics import YOLO
|
| 8 |
import cv2
|
| 9 |
+
from gradio_client.documentation import document, DocumentedType
|
| 10 |
+
|
| 11 |
+
# Import WebRTC components
|
| 12 |
+
from gradio_webrtc import (
|
| 13 |
+
RTCConfiguration,
|
| 14 |
+
WebRtcStreamerContext,
|
| 15 |
+
WebRtcMode,
|
| 16 |
+
WebRtcStreamer,
|
| 17 |
+
VideoTransformerBase,
|
| 18 |
+
VideoTransformerContext,
|
| 19 |
+
)
|
| 20 |
|
| 21 |
# Constants
|
| 22 |
COIN_CLASS_ID = 11 # 10sen coin
|
|
|
|
| 124 |
|
| 125 |
return [detections[i] for i in keep_indices]
|
| 126 |
|
| 127 |
+
class ScrewDetectionProcessor:
|
| 128 |
def __init__(self):
|
| 129 |
self.px_to_mm_ratio = None
|
| 130 |
self.detected_objects = []
|
|
|
|
| 157 |
if isinstance(frame, np.ndarray):
|
| 158 |
frame_np = frame
|
| 159 |
else:
|
| 160 |
+
# This handles the case if frame comes from other sources
|
| 161 |
frame_np = np.array(frame)
|
| 162 |
|
| 163 |
results = model(frame_np, conf=self.confidence_threshold)
|
| 164 |
|
| 165 |
+
if not results or len(results) == 0:
|
| 166 |
return frame_np, []
|
| 167 |
|
| 168 |
result = results[0]
|
|
|
|
| 236 |
# Convert back to BGR for OpenCV operations
|
| 237 |
return cv2.cvtColor(processed_img, cv2.COLOR_RGB2BGR), frame_detected_objects
|
| 238 |
|
| 239 |
+
# WebRTC Video Transformer
|
| 240 |
+
class ScrewDetectionTransformer(VideoTransformerBase):
|
| 241 |
+
def __init__(self):
|
| 242 |
+
self.processor = ScrewDetectionProcessor()
|
| 243 |
+
self.summary_text = "No detections yet."
|
| 244 |
+
|
| 245 |
+
def update_settings(self, iou_threshold, confidence_threshold, show_detections, show_summary):
|
| 246 |
+
self.processor.update_settings(
|
| 247 |
+
iou_threshold=iou_threshold,
|
| 248 |
+
confidence_threshold=confidence_threshold,
|
| 249 |
+
show_detections=show_detections,
|
| 250 |
+
show_summary=show_summary
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
def get_summary(self):
|
| 254 |
+
return self.processor.get_summary()
|
| 255 |
+
|
| 256 |
+
def transform(self, frame):
|
| 257 |
+
# Process frame will be called on each video frame
|
| 258 |
+
img = frame.to_ndarray(format="bgr24")
|
| 259 |
+
processed_frame, _ = self.processor.process_frame(img)
|
| 260 |
+
self.summary_text = self.processor.get_summary()
|
| 261 |
+
return processed_frame
|
| 262 |
+
|
| 263 |
def process_image(input_image, iou_threshold, confidence_threshold, show_detections, show_summary):
|
| 264 |
if input_image is None:
|
| 265 |
return None, "Please upload an image first."
|
|
|
|
| 271 |
frame = input_image
|
| 272 |
|
| 273 |
# Create a temporary processor for image processing
|
| 274 |
+
processor = ScrewDetectionProcessor()
|
| 275 |
processor.update_settings(iou_threshold, confidence_threshold, show_detections, show_summary)
|
| 276 |
processed_frame, _ = processor.process_frame(frame)
|
| 277 |
|
|
|
|
| 288 |
|
| 289 |
try:
|
| 290 |
# Create a processor for video processing
|
| 291 |
+
processor = ScrewDetectionProcessor()
|
| 292 |
processor.update_settings(iou_threshold, confidence_threshold, show_detections, show_summary)
|
| 293 |
|
| 294 |
cap = cv2.VideoCapture(video_path)
|
|
|
|
| 322 |
except Exception as e:
|
| 323 |
return [], f"Error processing video: {str(e)}"
|
| 324 |
|
| 325 |
+
def update_webrtc_settings(iou_threshold, confidence_threshold, show_detections, show_summary, webrtc_ctx):
|
| 326 |
+
if webrtc_ctx and webrtc_ctx.video_transformer:
|
| 327 |
+
webrtc_ctx.video_transformer.update_settings(
|
| 328 |
+
iou_threshold=iou_threshold,
|
| 329 |
+
confidence_threshold=confidence_threshold,
|
| 330 |
+
show_detections=show_detections,
|
| 331 |
+
show_summary=show_summary
|
| 332 |
+
)
|
| 333 |
+
return "Settings updated"
|
| 334 |
+
|
| 335 |
+
def get_webrtc_summary(webrtc_ctx):
|
| 336 |
+
if webrtc_ctx and webrtc_ctx.video_transformer:
|
| 337 |
+
return webrtc_ctx.video_transformer.get_summary()
|
| 338 |
+
return "WebRTC not active"
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
# Gradio Interface
|
| 341 |
with gr.Blocks(title="Screw Detection and Measurement") as demo:
|
|
|
|
| 379 |
outputs=[video_output, video_summary]
|
| 380 |
)
|
| 381 |
|
| 382 |
+
with gr.Tab("WebRTC Webcam"):
|
| 383 |
with gr.Row():
|
| 384 |
+
with gr.Column(scale=1):
|
| 385 |
webcam_iou = gr.Slider(label="IoU Threshold (NMS)", minimum=0.0, maximum=1.0, value=0.7, step=0.05)
|
| 386 |
webcam_conf = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.05)
|
| 387 |
webcam_show_det = gr.Checkbox(label="Show Detections", value=True)
|
| 388 |
webcam_show_sum = gr.Checkbox(label="Show Summary", value=True)
|
| 389 |
+
|
| 390 |
+
# Create a settings update button
|
| 391 |
+
update_settings = gr.Button("Update Settings")
|
| 392 |
+
|
| 393 |
+
# Summary textbox
|
| 394 |
+
webcam_summary = gr.Textbox(label="Detection Summary", interactive=False)
|
| 395 |
+
|
| 396 |
+
# Button to get summary
|
| 397 |
+
get_summary = gr.Button("Get Detection Summary")
|
| 398 |
+
|
| 399 |
+
with gr.Column(scale=2):
|
| 400 |
+
# Configure WebRTC with STUN servers
|
| 401 |
+
rtc_config = RTCConfiguration(
|
| 402 |
+
{"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
# Create the WebRTC component with our transformer
|
| 406 |
+
webrtc_ctx = gr.State(None)
|
| 407 |
+
|
| 408 |
+
# Use WebRtcStreamer with our transformer
|
| 409 |
+
webrtc = WebRtcStreamer(
|
| 410 |
+
key="screw-detection",
|
| 411 |
+
mode=WebRtcMode.SENDRECV,
|
| 412 |
+
rtc_configuration=rtc_config,
|
| 413 |
+
video_transformer_factory=ScrewDetectionTransformer,
|
| 414 |
+
async_transform=True,
|
| 415 |
+
)
|
| 416 |
|
| 417 |
+
# Connect the update settings button
|
| 418 |
+
update_settings.click(
|
| 419 |
+
update_webrtc_settings,
|
| 420 |
+
inputs=[webcam_iou, webcam_conf, webcam_show_det, webcam_show_sum, webrtc_ctx],
|
| 421 |
+
outputs=gr.Textbox(value="Settings updated", visible=False)
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
# Connect the get summary button
|
| 425 |
+
get_summary.click(
|
| 426 |
+
get_webrtc_summary,
|
| 427 |
+
inputs=[webrtc_ctx],
|
| 428 |
+
outputs=webcam_summary
|
| 429 |
)
|
| 430 |
|
| 431 |
# Add warning about model loading
|
| 432 |
if model is None:
|
| 433 |
gr.Warning("Model could not be loaded. Please ensure 'yolo11-obb12classes.pt' is available.")
|
| 434 |
|
| 435 |
+
demo.launch()
|