Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| from collections import Counter | |
| import time | |
| import os | |
| from ultralytics import YOLO | |
| import cv2 | |
| from gradio_client.documentation import document, DocumentedType | |
| # Import WebRTC components | |
| from gradio_webrtc import ( | |
| RTCConfiguration, | |
| WebRtcStreamerContext, | |
| WebRtcMode, | |
| WebRtcStreamer, | |
| VideoTransformerBase, | |
| VideoTransformerContext, | |
| ) | |
# Constants
COIN_CLASS_ID = 11  # 10sen coin
COIN_DIAMETER_MM = 18.80  # 10sen coin diameter in mm

# One row per model class, listed in class-id order: (display name, box colour).
# Both lookup tables below are derived from this single table so that a class
# name and its colour cannot drift apart.
_CATEGORIES = (
    ('long lag screw', (255, 0, 0)),
    ('wood screw', (0, 255, 0)),
    ('lag wood screw', (0, 0, 255)),
    ('short wood screw', (255, 255, 0)),
    ('shiny screw', (255, 0, 255)),
    ('black oxide screw', (0, 255, 255)),
    ('nut', (128, 0, 128)),
    ('bolt', (255, 165, 0)),
    ('large nut', (128, 128, 0)),
    ('machine screw', (0, 128, 128)),
    ('short machine screw', (128, 0, 0)),
    ('10sen Coin', (192, 192, 192)),
)

# class id -> display name
CLASS_NAMES = {class_id: name for class_id, (name, _) in enumerate(_CATEGORIES)}
# display name -> colour used for the box outline and label background
CATEGORY_COLORS = dict(_CATEGORIES)

LABEL_FONT_SIZE = 20  # point size requested for the label font
BORDER_WIDTH = 3  # outline width of a detection box, in pixels
# Load YOLO model - add a progress indicator
print("Loading YOLO model...")

# Path of the weights file, relative to the working directory.
MODEL_PATH = "yolo11-obb12classes.pt"

# `model` stays None whenever the weights cannot be used; the rest of the file
# checks for None before running inference.
model = None
if not os.path.exists(MODEL_PATH):
    # The file is known to be missing, so do not attempt a load that can only
    # fail and print a second, less helpful error.
    print("Model file not found! Please upload the model file to your Huggingface Space.")
else:
    try:
        model = YOLO(MODEL_PATH)
        print("Model loaded successfully!")
    except Exception as e:
        # Best effort at start-up: keep the app running without a model.
        print(f"Error loading YOLO model: {e}")
        model = None
def get_text_size(draw, text, font):
    """Return the ``(width, height)`` that *text* occupies when drawn with *font*.

    Uses ``draw.textbbox`` when the drawing object provides it and falls back
    to ``draw.textsize`` otherwise.
    """
    if not hasattr(draw, 'textbbox'):
        return draw.textsize(text, font=font)

    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    return right - left, bottom - top
def _to_numpy(value):
    """Return *value* as a NumPy array, moving tensor-like objects to the CPU first."""
    if hasattr(value, "cpu"):
        value = value.cpu()
    if hasattr(value, "numpy"):
        value = value.numpy()
    return np.asarray(value)


def non_max_suppression(detections, iou_threshold):
    """Class-aware greedy NMS that keeps multiple non-overlapping boxes.

    Overlap is measured on the axis-aligned ``xyxy`` box of each detection
    (not on the rotated box). A detection is dropped only when a
    higher-scoring detection of the *same class* overlaps it with an IoU
    strictly greater than ``iou_threshold``.

    Args:
        detections: Indexable collection of detections; each item exposes
            ``xyxy``, ``conf`` and ``cls`` whose first element describes the
            detection. ``None`` is treated as "no detections".
        iou_threshold: IoU above which a same-class box is suppressed.

    Returns:
        List of the surviving items of ``detections``, highest score first.
    """
    if detections is None or len(detections) == 0:
        return []

    boxes, scores, classes = [], [], []
    # Position of every usable detection inside `detections`. Detections
    # without a box are skipped, so positions in `boxes` and in `detections`
    # can differ; the result must be looked up through this mapping.
    source_indices = []
    for index, det in enumerate(detections):
        if len(det.xyxy) == 0:
            continue
        boxes.append(_to_numpy(det.xyxy[0]))
        scores.append(np.ravel(_to_numpy(det.conf[0]))[0])
        classes.append(np.ravel(_to_numpy(det.cls[0]))[0])
        source_indices.append(index)

    if not boxes:
        return []

    boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)
    scores = np.asarray(scores)
    classes = np.asarray(classes)
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    order = np.argsort(scores)[::-1]
    keep = []
    while order.size > 0:
        current, rest = order[0], order[1:]
        keep.append(source_indices[current])
        if rest.size == 0:
            break

        # IoU of the current box against every remaining candidate at once.
        x_a = np.maximum(boxes[current, 0], boxes[rest, 0])
        y_a = np.maximum(boxes[current, 1], boxes[rest, 1])
        x_b = np.minimum(boxes[current, 2], boxes[rest, 2])
        y_b = np.minimum(boxes[current, 3], boxes[rest, 3])
        inter = np.clip(x_b - x_a, 0, None) * np.clip(y_b - y_a, 0, None)
        union = areas[current] + areas[rest] - inter
        # Degenerate (zero-area) pairs count as non-overlapping.
        iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)

        suppressed = (classes[rest] == classes[current]) & (iou > iou_threshold)
        order = rest[~suppressed]

    return [detections[i] for i in keep]
class ScrewDetectionProcessor:
    """Detect fasteners in BGR frames, annotate them and keep running counts.

    A detected coin (class ``COIN_CLASS_ID``) of known diameter is used once
    to derive a millimetre-per-pixel scale; every later label reports sizes
    with that scale.

    NOTE(review): ``detected_objects`` is extended on every processed frame
    and never cleared, so the summary counts detections per frame (not unique
    objects) and the list grows for as long as a stream runs.
    """

    def __init__(self):
        # Millimetres per pixel; None until a coin has been measured.
        self.px_to_mm_ratio = None
        # Class names of all non-coin detections seen so far.
        self.detected_objects = []
        self.show_detections = True
        self.show_summary = True
        self.iou_threshold = 0.7
        self.confidence_threshold = 0.5
        # Label font, loaded on first use and then reused for every frame.
        self._font = None

    def update_settings(self, iou_threshold, confidence_threshold, show_detections, show_summary):
        """Replace the NMS/confidence thresholds and the two display toggles."""
        self.iou_threshold = iou_threshold
        self.confidence_threshold = confidence_threshold
        self.show_detections = show_detections
        self.show_summary = show_summary

    def get_summary(self):
        """Return a text listing of how often each class has been detected.

        Returns the placeholder message when the summary is switched off or
        nothing has been detected yet.
        """
        if not self.show_summary or not self.detected_objects:
            return "No screws or nuts detected yet."
        counts = Counter(self.detected_objects)
        lines = [f"- {name}: {count}\n" for name, count in counts.items()]
        return "Detection Summary:\n" + "".join(lines)

    @staticmethod
    def _load_label_font():
        """Return the best available font at ``LABEL_FONT_SIZE``."""
        for font_name in ("DejaVuSans.ttf", "Arial.ttf"):
            try:
                return ImageFont.truetype(font_name, LABEL_FONT_SIZE)
            except (OSError, ImportError):
                continue
        try:
            # Newer Pillow releases can scale the built-in font.
            return ImageFont.load_default(size=LABEL_FONT_SIZE)
        except (TypeError, OSError, ImportError):
            # Older Pillow: fixed-size bitmap font only.
            return ImageFont.load_default()

    def _get_font(self):
        """Return the cached label font, loading it on first use."""
        font = getattr(self, "_font", None)
        if font is None:
            font = self._font = self._load_label_font()
        return font

    def _calibrate_from_coin(self, detections):
        """Set ``px_to_mm_ratio`` from the first usable coin detection, if any."""
        for detection in detections:
            if len(detection.cls) == 0 or len(detection.xywhr) == 0:
                continue
            if int(detection.cls[0]) != COIN_CLASS_ID:
                continue
            coin_xywhr = detection.xywhr[0]
            avg_px_diameter = (float(coin_xywhr[2]) + float(coin_xywhr[3])) / 2
            if avg_px_diameter > 0:
                # Stored as a plain float so later formatting/arithmetic does
                # not depend on the tensor type returned by the model.
                self.px_to_mm_ratio = COIN_DIAMETER_MM / avg_px_diameter
                return

    def _size_suffix(self, class_id, xywhr):
        """Return the ", Dia: ..." / ", Length: ..." part of a label."""
        width_px = float(xywhr[2])
        height_px = float(xywhr[3])
        is_coin = class_id == COIN_CLASS_ID
        if not self.px_to_mm_ratio:
            return ", Dia: N/A (No Ratio)" if is_coin else ", Length: N/A (No Coin)"
        if is_coin:
            # Same measure as the calibration (mean of the oriented box
            # sides), so the reported diameter does not grow with rotation.
            diameter_mm = (width_px + height_px) / 2 * self.px_to_mm_ratio
            return f", Dia: {diameter_mm:.2f}mm"
        length_mm = max(width_px, height_px) * self.px_to_mm_ratio
        return f", Length: {length_mm:.2f}mm"

    def _draw_detection(self, draw, font, box, color, label_text):
        """Draw one box outline with its filled label banner."""
        x1, y1, x2, y2 = box
        draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=BORDER_WIDTH)
        text_width, text_height = get_text_size(draw, label_text, font)
        # Keep the banner on the canvas when the box touches the top edge.
        label_top = max(y1 - text_height - 5, 0)
        draw.rectangle(
            [(x1, label_top), (x1 + text_width + 5, label_top + text_height + 5)],
            fill=color,
        )
        draw.text((x1 + 2, label_top + 2), label_text, fill=(255, 255, 255), font=font)

    def process_frame(self, frame):
        """Run detection on one BGR frame and return the annotated result.

        Args:
            frame: Image in BGR channel order (NumPy array, or anything
                ``np.array`` can convert).

        Returns:
            ``(annotated_bgr_frame, names)`` where ``names`` lists the class
            name of every non-coin detection in this frame. When no model is
            loaded the input is returned unchanged with an empty list.
        """
        if model is None:
            return frame, []

        frame_np = frame if isinstance(frame, np.ndarray) else np.array(frame)

        results = model(frame_np, conf=self.confidence_threshold)
        if not results:
            return frame_np, []

        obb = results[0].obb
        filtered_detections = (
            non_max_suppression(obb, self.iou_threshold) if obb is not None else []
        )

        pil_image = Image.fromarray(cv2.cvtColor(frame_np, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        font = self._get_font()

        # Find coin for scaling
        if self.px_to_mm_ratio is None:
            self._calibrate_from_coin(filtered_detections)

        frame_detected_objects = []
        for detection in filtered_detections:
            if len(detection.cls) == 0 or len(detection.xywhr) == 0 or len(detection.xyxy) == 0:
                continue
            class_id = int(detection.cls[0])
            class_name = CLASS_NAMES.get(class_id, f"Class {int(class_id)}")
            color = CATEGORY_COLORS.get(class_name, (0, 255, 0))
            if class_id != COIN_CLASS_ID:
                frame_detected_objects.append(class_name)

            label_text = f"{class_name}" + self._size_suffix(class_id, detection.xywhr[0])
            if self.show_detections:
                box = tuple(int(value) for value in detection.xyxy[0])
                self._draw_detection(draw, font, box, color, label_text)

        self.detected_objects.extend(frame_detected_objects)
        # Convert back to BGR for OpenCV operations
        return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR), frame_detected_objects
# WebRTC Video Transformer
class ScrewDetectionTransformer(VideoTransformerBase):
    """Video transformer that runs a ``ScrewDetectionProcessor`` on each frame."""

    def __init__(self):
        self.processor = ScrewDetectionProcessor()
        # Text of the most recent summary; refreshed after every frame.
        self.summary_text = "No detections yet."

    def update_settings(self, iou_threshold, confidence_threshold, show_detections, show_summary):
        """Forward the four detection settings to the wrapped processor."""
        self.processor.update_settings(
            iou_threshold,
            confidence_threshold,
            show_detections,
            show_summary,
        )

    def get_summary(self):
        """Return the wrapped processor's current detection summary."""
        return self.processor.get_summary()

    def transform(self, frame):
        """Annotate one incoming video frame and return it as a BGR array."""
        bgr_image = frame.to_ndarray(format="bgr24")
        annotated = self.processor.process_frame(bgr_image)[0]
        self.summary_text = self.processor.get_summary()
        return annotated
def process_image(input_image, iou_threshold, confidence_threshold, show_detections, show_summary):
    """Detect and annotate fasteners in a single uploaded image.

    Args:
        input_image: RGB image from the Gradio image input (NumPy array or
            anything ``np.array`` can convert), or ``None`` when nothing was
            uploaded.
        iou_threshold: IoU threshold used for non-max suppression.
        confidence_threshold: Minimum detection confidence.
        show_detections: Whether boxes and labels are drawn.
        show_summary: Whether the textual summary is produced.

    Returns:
        ``(annotated_rgb_image, summary_text)``; ``(None, message)`` when no
        image was supplied.
    """
    if input_image is None:
        return None, "Please upload an image first."

    # Convert PIL to numpy array if needed
    frame = input_image if isinstance(input_image, np.ndarray) else np.array(input_image)

    # The processor works on BGR frames (it converts BGR->RGB before drawing),
    # whereas the upload is RGB. Without this conversion the result comes back
    # with the red and blue channels swapped.
    if frame.ndim == 2:
        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
    elif frame.shape[2] == 4:
        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)
    else:
        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

    # Create a temporary processor for image processing
    processor = ScrewDetectionProcessor()
    processor.update_settings(iou_threshold, confidence_threshold, show_detections, show_summary)
    processed_frame, _ = processor.process_frame(frame_bgr)

    # Convert BGR to RGB for display in Gradio
    processed_frame_rgb = cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB)
    return processed_frame_rgb, processor.get_summary()
def process_video(video_path, iou_threshold, confidence_threshold, show_detections, show_summary):
    """Annotate the first frames of a video file.

    At most 20 frames are processed to keep memory use bounded. The summary
    is accumulated by one processor over all processed frames.

    Args:
        video_path: Path of the uploaded video, or ``None``.
        iou_threshold: IoU threshold used for non-max suppression.
        confidence_threshold: Minimum detection confidence.
        show_detections: Whether boxes and labels are drawn.
        show_summary: Whether the textual summary is produced.

    Returns:
        ``(frames, summary_text)`` where ``frames`` is a list of annotated RGB
        arrays; ``([], message)`` on any failure.
    """
    if video_path is None:
        return [], "Please upload a video first."

    max_frames = 20  # Limit frames to prevent memory issues
    cap = None
    try:
        # Create a processor for video processing
        processor = ScrewDetectionProcessor()
        processor.update_settings(iou_threshold, confidence_threshold, show_detections, show_summary)

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return [], "Error: Could not open video file."

        frames = []
        while len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame, _ = processor.process_frame(frame)
            # Convert BGR to RGB for display
            frames.append(cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB))

        if not frames:
            return [], "No frames could be processed from the video."
        return frames, processor.get_summary()
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return [], f"Error processing video: {str(e)}"
    finally:
        # Release the capture on every path, including when a frame fails.
        if cap is not None:
            cap.release()
def update_webrtc_settings(iou_threshold, confidence_threshold, show_detections, show_summary, webrtc_ctx):
    """Push the current UI settings to the active video transformer, if any.

    Does nothing when ``webrtc_ctx`` or its ``video_transformer`` is missing.
    Always returns the status string "Settings updated".
    """
    transformer = webrtc_ctx.video_transformer if webrtc_ctx else None
    if transformer:
        settings = {
            "iou_threshold": iou_threshold,
            "confidence_threshold": confidence_threshold,
            "show_detections": show_detections,
            "show_summary": show_summary,
        }
        transformer.update_settings(**settings)
    return "Settings updated"
def get_webrtc_summary(webrtc_ctx):
    """Return the active transformer's summary, or a notice when none is active."""
    if not (webrtc_ctx and webrtc_ctx.video_transformer):
        return "WebRTC not active"
    return webrtc_ctx.video_transformer.get_summary()
# Gradio Interface
def _detection_controls():
    """Create the two threshold sliders and two toggles shared by every tab.

    Must be called inside the layout context the controls belong to; returns
    ``(iou_slider, confidence_slider, show_detections_box, show_summary_box)``.
    """
    iou = gr.Slider(label="IoU Threshold (NMS)", minimum=0.0, maximum=1.0, value=0.7, step=0.05)
    conf = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.05)
    show_det = gr.Checkbox(label="Show Detections", value=True)
    show_sum = gr.Checkbox(label="Show Summary", value=True)
    return iou, conf, show_det, show_sum


with gr.Blocks(title="Screw Detection and Measurement") as demo:
    gr.Markdown("# 🔍 Screw Detection and Measurement (YOLOv11 OBB)")

    # --- Single image -------------------------------------------------------
    with gr.Tab("Image"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Upload Image", type="numpy")
                image_iou, image_conf, image_show_det, image_show_sum = _detection_controls()
                image_button = gr.Button("Process Image")
            with gr.Column():
                image_output = gr.Image(label="Processed Image")
                image_summary = gr.Textbox(label="Summary", interactive=False)
        image_button.click(
            process_image,
            inputs=[image_input, image_iou, image_conf, image_show_det, image_show_sum],
            outputs=[image_output, image_summary],
        )

    # --- Uploaded video (first frames shown as a gallery) ---------------------
    with gr.Tab("Video"):
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video")
                video_iou, video_conf, video_show_det, video_show_sum = _detection_controls()
                video_button = gr.Button("Process Video")
            with gr.Column():
                video_output = gr.Gallery(label="Processed Frames")
                video_summary = gr.Textbox(label="Summary", interactive=False)
        video_button.click(
            process_video,
            inputs=[video_input, video_iou, video_conf, video_show_det, video_show_sum],
            outputs=[video_output, video_summary],
        )

    # --- Live webcam ----------------------------------------------------------
    with gr.Tab("WebRTC Webcam"):
        with gr.Row():
            with gr.Column(scale=1):
                webcam_iou, webcam_conf, webcam_show_det, webcam_show_sum = _detection_controls()
                # Create a settings update button
                update_settings = gr.Button("Update Settings")
                # Summary textbox
                webcam_summary = gr.Textbox(label="Detection Summary", interactive=False)
                # Button to get summary
                get_summary = gr.Button("Get Detection Summary")
            with gr.Column(scale=2):
                # Configure WebRTC with STUN servers
                rtc_config = RTCConfiguration(
                    {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
                )
                # NOTE(review): this State starts as None and nothing in this
                # file assigns it, so both webcam buttons below currently take
                # their "no active transformer" branch - confirm intended wiring.
                webrtc_ctx = gr.State(None)
                # NOTE(review): WebRtcStreamer / WebRtcMode / the
                # video_transformer_factory argument look like the
                # streamlit-webrtc API - confirm gradio_webrtc really exports them.
                webrtc = WebRtcStreamer(
                    key="screw-detection",
                    mode=WebRtcMode.SENDRECV,
                    rtc_configuration=rtc_config,
                    video_transformer_factory=ScrewDetectionTransformer,
                    async_transform=True,
                )
        # Connect the update settings button
        update_settings.click(
            update_webrtc_settings,
            inputs=[webcam_iou, webcam_conf, webcam_show_det, webcam_show_sum, webrtc_ctx],
            outputs=gr.Textbox(value="Settings updated", visible=False),
        )
        # Connect the get summary button
        get_summary.click(
            get_webrtc_summary,
            inputs=[webrtc_ctx],
            outputs=webcam_summary,
        )

    # Add warning about model loading
    if model is None:
        gr.Warning("Model could not be loaded. Please ensure 'yolo11-obb12classes.pt' is available.")

demo.launch()