"""Screw detection and measurement app (YOLOv11 OBB) served through Gradio.

Detects screws/nuts in still images, uploaded videos, and a live WebRTC
webcam stream.  A Malaysian 10sen coin (known diameter 18.80 mm) acts as the
physical size reference: once a coin is detected, a pixel→mm ratio is derived
and object lengths/diameters are annotated in millimetres.
"""

import os
import time  # NOTE(review): unused in this file — kept to avoid breaking external tooling
from collections import Counter

import cv2
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO

# NOTE(review): `document`/`DocumentedType` are never used below — confirm they
# are needed (e.g. for a Gradio plugin side effect) before removing.
from gradio_client.documentation import document, DocumentedType

# Import WebRTC components.
# NOTE(review): these names resemble the streamlit-webrtc API; verify they
# actually exist in the installed `gradio_webrtc` package version.
from gradio_webrtc import (
    RTCConfiguration,
    WebRtcStreamerContext,
    WebRtcMode,
    WebRtcStreamer,
    VideoTransformerBase,
    VideoTransformerContext,
)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
COIN_CLASS_ID = 11          # class index of the 10sen coin in the model
COIN_DIAMETER_MM = 18.80    # real-world 10sen coin diameter, used for scaling

# Model class index -> human-readable name.
CLASS_NAMES = {
    0: 'long lag screw',
    1: 'wood screw',
    2: 'lag wood screw',
    3: 'short wood screw',
    4: 'shiny screw',
    5: 'black oxide screw',
    6: 'nut',
    7: 'bolt',
    8: 'large nut',
    9: 'machine screw',
    10: 'short machine screw',
    11: '10sen Coin',
}

# Per-category box/label colour (RGB tuples for PIL drawing).
CATEGORY_COLORS = {
    'long lag screw': (255, 0, 0),
    'wood screw': (0, 255, 0),
    'lag wood screw': (0, 0, 255),
    'short wood screw': (255, 255, 0),
    'shiny screw': (255, 0, 255),
    'black oxide screw': (0, 255, 255),
    'nut': (128, 0, 128),
    'bolt': (255, 165, 0),
    'large nut': (128, 128, 0),
    'machine screw': (0, 128, 128),
    'short machine screw': (128, 0, 0),
    '10sen Coin': (192, 192, 192),
}

LABEL_FONT_SIZE = 20
BORDER_WIDTH = 3

# ---------------------------------------------------------------------------
# Model loading (module import time) — add a progress indicator.
# ---------------------------------------------------------------------------
print("Loading YOLO model...")

# Check if the model file exists first so the user gets an actionable hint.
if not os.path.exists("yolo11-obb12classes.pt"):
    print("Model file not found! Please upload the model file to your Huggingface Space.")

try:
    model = YOLO("yolo11-obb12classes.pt")
    print("Model loaded successfully!")
except Exception as e:
    # Keep the app importable even without a model; downstream code checks
    # for `model is None` and degrades gracefully.
    print(f"Error loading YOLO model: {e}")
    model = None


def get_text_size(draw, text, font):
    """Return (width, height) of *text* rendered with *font*.

    Uses ``textbbox`` on modern Pillow and falls back to the deprecated
    ``textsize`` on older versions.
    """
    if hasattr(draw, 'textbbox'):
        bbox = draw.textbbox((0, 0), text, font=font)
        return bbox[2] - bbox[0], bbox[3] - bbox[1]
    return draw.textsize(text, font=font)


def non_max_suppression(detections, iou_threshold):
    """Improved NMS for OBB that keeps multiple non-overlapping boxes.

    Operates on the axis-aligned ``xyxy`` projection of each oriented box.
    Suppression is class-aware: a lower-scoring box is dropped only when it
    overlaps a kept box of the SAME class above *iou_threshold*.

    Args:
        detections: iterable of ultralytics OBB detection rows exposing
            ``.xyxy``, ``.conf`` and ``.cls`` tensors.
        iou_threshold: IoU above which same-class boxes are suppressed.

    Returns:
        list of the surviving detection rows, ordered by descending score.
    """
    if len(detections) == 0:
        return []

    boxes, scores, classes = [], [], []
    for det in detections:
        if len(det.xyxy) > 0:
            boxes.append(det.xyxy[0].cpu().numpy())
            scores.append(det.conf[0].cpu().numpy())
            classes.append(det.cls[0].cpu().numpy())

    if not boxes:
        return []

    boxes = np.array(boxes)
    scores = np.array(scores)
    classes = np.array(classes)

    # Process candidates from highest to lowest confidence.
    indices = np.argsort(scores)[::-1]
    keep_indices = []

    while len(indices) > 0:
        current = indices[0]
        keep_indices.append(current)
        rest = indices[1:]

        ious = []
        for i in rest:
            box1 = boxes[current]
            box2 = boxes[i]
            # Intersection rectangle of the two axis-aligned boxes.
            xA = max(box1[0], box2[0])
            yA = max(box1[1], box2[1])
            xB = min(box1[2], box2[2])
            yB = min(box1[3], box2[3])
            interArea = max(0, xB - xA) * max(0, yB - yA)
            box1Area = (box1[2] - box1[0]) * (box1[3] - box1[1])
            box2Area = (box2[2] - box2[0]) * (box2[3] - box2[1])
            unionArea = box1Area + box2Area - interArea
            # Guard against degenerate zero-area boxes.
            iou = interArea / unionArea if unionArea > 0 else 0.0
            ious.append(iou)

        ious = np.array(ious)
        same_class = (classes[rest] == classes[current])
        # Drop only same-class boxes that overlap too much with the kept one.
        to_keep = ~(same_class & (ious > iou_threshold))
        indices = rest[to_keep]

    return [detections[i] for i in keep_indices]


class ScrewDetectionProcessor:
    """Stateful per-frame detector/annotator.

    Holds the pixel→mm calibration (derived once from the first detected
    coin), a running list of detected object names, and display settings.
    """

    def __init__(self):
        self.px_to_mm_ratio = None      # set once a 10sen coin is detected
        self.detected_objects = []      # cumulative names across frames
        self.show_detections = True
        self.show_summary = True
        self.iou_threshold = 0.7
        self.confidence_threshold = 0.5

    def update_settings(self, iou_threshold, confidence_threshold,
                        show_detections, show_summary):
        """Apply UI-controlled thresholds and display toggles."""
        self.iou_threshold = iou_threshold
        self.confidence_threshold = confidence_threshold
        self.show_detections = show_detections
        self.show_summary = show_summary

    def get_summary(self):
        """Return a per-class count summary of everything detected so far."""
        if not self.show_summary or not self.detected_objects:
            return "No screws or nuts detected yet."
        screw_counts = Counter(self.detected_objects)
        summary_text = "Detection Summary:\n"
        for name, count in screw_counts.items():
            summary_text += f"- {name}: {count}\n"
        return summary_text

    def process_frame(self, frame):
        """Run detection on one frame and draw annotated boxes/labels.

        Args:
            frame: BGR numpy array (or array-convertible image).

        Returns:
            (annotated BGR numpy array, list of non-coin class names found
            in this frame).
        """
        if model is None:
            return frame, []

        # Ensure frame is in correct format.
        if isinstance(frame, np.ndarray):
            frame_np = frame
        else:
            # This handles the case if frame comes from other sources.
            frame_np = np.array(frame)

        results = model(frame_np, conf=self.confidence_threshold)
        if not results or len(results) == 0:
            return frame_np, []

        result = results[0]
        filtered_detections = non_max_suppression(result.obb, self.iou_threshold)

        # Draw with PIL (RGB); converted back to BGR at the end.
        pil_image = Image.fromarray(cv2.cvtColor(frame_np, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)

        try:
            # Use a system font that should be available on most platforms.
            font = ImageFont.truetype("DejaVuSans.ttf", LABEL_FONT_SIZE)
        except OSError:
            try:
                font = ImageFont.truetype("Arial.ttf", LABEL_FONT_SIZE)
            except OSError:
                font = ImageFont.load_default()
                # NOTE(review): assigning .size on the bitmap default font
                # does not rescale it on most Pillow versions — best effort.
                if hasattr(font, 'size'):
                    font.size = LABEL_FONT_SIZE

        frame_detected_objects = []

        # Find coin for scaling — calibrate once, on the first coin seen.
        if self.px_to_mm_ratio is None:
            for detection in filtered_detections:
                if (len(detection.cls) > 0
                        and int(detection.cls[0]) == COIN_CLASS_ID
                        and len(detection.xywhr) > 0):
                    coin_xywhr = detection.xywhr[0]
                    width_px = coin_xywhr[2]
                    height_px = coin_xywhr[3]
                    # Average the OBB sides to approximate the coin diameter.
                    avg_px_diameter = (width_px + height_px) / 2
                    if avg_px_diameter > 0:
                        self.px_to_mm_ratio = COIN_DIAMETER_MM / avg_px_diameter
                    break

        # Draw detections.
        for detection in filtered_detections:
            if (len(detection.cls) > 0 and len(detection.xywhr) > 0
                    and len(detection.xyxy) > 0):
                class_id = int(detection.cls[0])
                x1, y1, x2, y2 = map(int, detection.xyxy[0])
                class_name = CLASS_NAMES.get(class_id, f"Class {int(class_id)}")
                color = CATEGORY_COLORS.get(class_name, (0, 255, 0))
                label_text = f"{class_name}"

                if class_id != COIN_CLASS_ID:
                    frame_detected_objects.append(class_name)

                if class_id == COIN_CLASS_ID and self.px_to_mm_ratio:
                    # Coin: report measured diameter from the AABB sides.
                    diameter_px = (x2 - x1 + y2 - y1) / 2
                    diameter_mm = diameter_px * self.px_to_mm_ratio
                    label_text += f", Dia: {diameter_mm:.2f}mm"
                elif class_id != COIN_CLASS_ID and self.px_to_mm_ratio:
                    # Screw/nut: length is the longer OBB side, scaled to mm.
                    xywhr = detection.xywhr[0]
                    width_px = xywhr[2]
                    height_px = xywhr[3]
                    length_px = max(width_px, height_px)
                    length_mm = length_px * self.px_to_mm_ratio
                    label_text += f", Length: {length_mm:.2f}mm"
                elif class_id != COIN_CLASS_ID:
                    label_text += ", Length: N/A (No Coin)"
                elif class_id == COIN_CLASS_ID:
                    label_text += ", Dia: N/A (No Ratio)"

                if self.show_detections:
                    draw.rectangle([(x1, y1), (x2, y2)],
                                   outline=color, width=BORDER_WIDTH)
                    text_width, text_height = get_text_size(draw, label_text, font)
                    # Filled label background above the box, then white text.
                    draw.rectangle(
                        [(x1, y1 - text_height - 5), (x1 + text_width + 5, y1)],
                        fill=color)
                    draw.text((x1 + 2, y1 - text_height - 3), label_text,
                              fill=(255, 255, 255), font=font)

        self.detected_objects.extend(frame_detected_objects)

        processed_img = np.array(pil_image)
        # Convert back to BGR for OpenCV operations.
        return cv2.cvtColor(processed_img, cv2.COLOR_RGB2BGR), frame_detected_objects


# WebRTC Video Transformer
class ScrewDetectionTransformer(VideoTransformerBase):
    """Adapts ScrewDetectionProcessor to the WebRTC per-frame callback."""

    def __init__(self):
        self.processor = ScrewDetectionProcessor()
        self.summary_text = "No detections yet."

    def update_settings(self, iou_threshold, confidence_threshold,
                        show_detections, show_summary):
        """Forward UI settings to the wrapped processor."""
        self.processor.update_settings(
            iou_threshold=iou_threshold,
            confidence_threshold=confidence_threshold,
            show_detections=show_detections,
            show_summary=show_summary,
        )

    def get_summary(self):
        return self.processor.get_summary()

    def transform(self, frame):
        # Process frame will be called on each video frame.
        img = frame.to_ndarray(format="bgr24")
        processed_frame, _ = self.processor.process_frame(img)
        self.summary_text = self.processor.get_summary()
        return processed_frame


def process_image(input_image, iou_threshold, confidence_threshold,
                  show_detections, show_summary):
    """Gradio handler for the Image tab: annotate one uploaded image.

    Returns (RGB annotated image, summary string).
    """
    if input_image is None:
        return None, "Please upload an image first."

    # Convert PIL to numpy array if needed.
    if not isinstance(input_image, np.ndarray):
        frame = np.array(input_image)
    else:
        frame = input_image

    # Create a temporary processor for image processing (fresh calibration).
    processor = ScrewDetectionProcessor()
    processor.update_settings(iou_threshold, confidence_threshold,
                              show_detections, show_summary)

    processed_frame, _ = processor.process_frame(frame)

    # Convert BGR to RGB for display in Gradio.
    processed_frame_rgb = cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB)
    summary = processor.get_summary()
    return processed_frame_rgb, summary


def process_video(video_path, iou_threshold, confidence_threshold,
                  show_detections, show_summary):
    """Gradio handler for the Video tab: annotate up to 20 frames.

    Returns (list of RGB frames for the gallery, summary string).
    """
    if video_path is None:
        return [], "Please upload a video first."

    try:
        # Create a processor for video processing (state shared across frames).
        processor = ScrewDetectionProcessor()
        processor.update_settings(iou_threshold, confidence_threshold,
                                  show_detections, show_summary)

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return [], "Error: Could not open video file."

        frames = []
        frame_count = 0
        max_frames = 20  # Limit frames to prevent memory issues

        while cap.isOpened() and frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame, _ = processor.process_frame(frame)
            # Convert BGR to RGB for display.
            processed_frame_rgb = cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB)
            frames.append(processed_frame_rgb)
            frame_count += 1

        cap.release()
        summary = processor.get_summary()

        if not frames:
            return [], "No frames could be processed from the video."
        return frames, summary
    except Exception as e:
        return [], f"Error processing video: {str(e)}"


def update_webrtc_settings(iou_threshold, confidence_threshold,
                           show_detections, show_summary, webrtc_ctx):
    """Push slider/checkbox values into the live WebRTC transformer.

    NOTE(review): the `webrtc_ctx` State is initialised to None and never
    re-assigned anywhere in this file, so this branch appears unreachable —
    verify how the streamer context is meant to be stored.
    """
    if webrtc_ctx and webrtc_ctx.video_transformer:
        webrtc_ctx.video_transformer.update_settings(
            iou_threshold=iou_threshold,
            confidence_threshold=confidence_threshold,
            show_detections=show_detections,
            show_summary=show_summary,
        )
    return "Settings updated"


def get_webrtc_summary(webrtc_ctx):
    """Read the running detection summary from the WebRTC transformer.

    NOTE(review): same caveat as update_webrtc_settings — `webrtc_ctx` is
    never populated, so this will always report "WebRTC not active".
    """
    if webrtc_ctx and webrtc_ctx.video_transformer:
        return webrtc_ctx.video_transformer.get_summary()
    return "WebRTC not active"


# ---------------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------------
with gr.Blocks(title="Screw Detection and Measurement") as demo:
    gr.Markdown("# 🔍 Screw Detection and Measurement (YOLOv11 OBB)")

    with gr.Tab("Image"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Upload Image", type="numpy")
                image_iou = gr.Slider(label="IoU Threshold (NMS)",
                                      minimum=0.0, maximum=1.0, value=0.7, step=0.05)
                image_conf = gr.Slider(label="Confidence Threshold",
                                       minimum=0.0, maximum=1.0, value=0.5, step=0.05)
                image_show_det = gr.Checkbox(label="Show Detections", value=True)
                image_show_sum = gr.Checkbox(label="Show Summary", value=True)
                image_button = gr.Button("Process Image")
            with gr.Column():
                image_output = gr.Image(label="Processed Image")
                image_summary = gr.Textbox(label="Summary", interactive=False)

        image_button.click(
            process_image,
            inputs=[image_input, image_iou, image_conf,
                    image_show_det, image_show_sum],
            outputs=[image_output, image_summary],
        )

    with gr.Tab("Video"):
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video")
                video_iou = gr.Slider(label="IoU Threshold (NMS)",
                                      minimum=0.0, maximum=1.0, value=0.7, step=0.05)
                video_conf = gr.Slider(label="Confidence Threshold",
                                       minimum=0.0, maximum=1.0, value=0.5, step=0.05)
                video_show_det = gr.Checkbox(label="Show Detections", value=True)
                video_show_sum = gr.Checkbox(label="Show Summary", value=True)
                video_button = gr.Button("Process Video")
            with gr.Column():
                video_output = gr.Gallery(label="Processed Frames")
                video_summary = gr.Textbox(label="Summary", interactive=False)

        video_button.click(
            process_video,
            inputs=[video_input, video_iou, video_conf,
                    video_show_det, video_show_sum],
            outputs=[video_output, video_summary],
        )

    with gr.Tab("WebRTC Webcam"):
        with gr.Row():
            with gr.Column(scale=1):
                webcam_iou = gr.Slider(label="IoU Threshold (NMS)",
                                       minimum=0.0, maximum=1.0, value=0.7, step=0.05)
                webcam_conf = gr.Slider(label="Confidence Threshold",
                                        minimum=0.0, maximum=1.0, value=0.5, step=0.05)
                webcam_show_det = gr.Checkbox(label="Show Detections", value=True)
                webcam_show_sum = gr.Checkbox(label="Show Summary", value=True)
                # Create a settings update button.
                update_settings = gr.Button("Update Settings")
                # Summary textbox.
                webcam_summary = gr.Textbox(label="Detection Summary",
                                            interactive=False)
                # Button to get summary.
                get_summary = gr.Button("Get Detection Summary")
            with gr.Column(scale=2):
                # Configure WebRTC with STUN servers.
                rtc_config = RTCConfiguration(
                    {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
                )
                # Create the WebRTC component with our transformer.
                # NOTE(review): this State stays None for the app's lifetime —
                # the settings/summary callbacks below can never reach the
                # transformer; confirm intended wiring with gradio_webrtc docs.
                webrtc_ctx = gr.State(None)
                # Use WebRtcStreamer with our transformer.
                webrtc = WebRtcStreamer(
                    key="screw-detection",
                    mode=WebRtcMode.SENDRECV,
                    rtc_configuration=rtc_config,
                    video_transformer_factory=ScrewDetectionTransformer,
                    async_transform=True,
                )

        # Connect the update settings button.
        update_settings.click(
            update_webrtc_settings,
            inputs=[webcam_iou, webcam_conf, webcam_show_det,
                    webcam_show_sum, webrtc_ctx],
            outputs=gr.Textbox(value="Settings updated", visible=False),
        )

        # Connect the get summary button.
        get_summary.click(
            get_webrtc_summary,
            inputs=[webrtc_ctx],
            outputs=webcam_summary,
        )

    # Add warning about model loading.
    # NOTE(review): gr.Warning outside an event handler runs at build time,
    # not in a user session — verify it surfaces where intended.
    if model is None:
        gr.Warning("Model could not be loaded. Please ensure 'yolo11-obb12classes.pt' is available.")

demo.launch()