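"""Gradio Space: RF-DETR-based UI element detector for screenshots and mockups."""
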
import os

# Set before importing torch / gradio so the env vars take effect
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ['GRADIO_DEFAULT_LANG'] = 'en'

import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
from typing import Tuple, List
import json

from rfdetr.detr import RFDETRMedium
# UI element classes
CLASSES = ['button', 'field', 'heading', 'iframe', 'image', 'label', 'link', 'text']

# Single color for all boxes (BGR format for OpenCV)
BOX_COLOR = (0, 255, 0)  # Green

# Global model variable
model = None

def load_model(model_path: str = "model.pth"):
    """Load the RF-DETR model, caching it in the module-level `model` variable."""
    global model
    if model is None:
        print("Loading RF-DETR model...")
        model = RFDETRMedium(pretrain_weights=model_path, resolution=1600)
        print("Model loaded successfully!")
    return model
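
# Standalone usage (a minimal sketch; assumes "model.pth", the default
# checkpoint path above, is present next to this script):
#   img = np.array(Image.open("screenshot.png").convert("RGB"))
#   detections = load_model().predict(img, threshold=0.35)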

def draw_detections(
    image: np.ndarray,
    boxes: List[Tuple[int, int, int, int]],
    scores: List[float],
    classes: List[int],
    thickness: int = 3,
    font_scale: float = 0.6
) -> np.ndarray:
    """Draw detection boxes and labels on an image."""
    img_with_boxes = image.copy()
    for box, score, cls_id in zip(boxes, scores, classes):
        x1, y1, x2, y2 = map(int, box)

        # Draw rectangle
        cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), BOX_COLOR, thickness)

        # Prepare label with confidence score only
        label = f"{score:.2f}"

        # Calculate label size and position
        (label_width, label_height), baseline = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness=2
        )

        # Draw label background
        label_y = max(y1 - 10, label_height + 10)
        cv2.rectangle(
            img_with_boxes,
            (x1, label_y - label_height - baseline - 5),
            (x1 + label_width + 5, label_y + baseline - 5),
            BOX_COLOR,
            -1
        )

        # Draw label text
        cv2.putText(
            img_with_boxes,
            label,
            (x1 + 2, label_y - baseline - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            (255, 255, 255),
            thickness=2
        )
    return img_with_boxes

def detections_to_raw_json(detections) -> str:
    """Serialize a supervision Detections object to pretty-printed JSON."""
    out = []
    for box, score, cls_id in zip(
        detections.xyxy,
        detections.confidence,
        detections.class_id
    ):
        cid = int(cls_id)
        out.append({
            "class_id": cid,
            "class_name": CLASSES[cid] if 0 <= cid < len(CLASSES) else str(cid),
            "score": float(score),
            "box_xyxy": [
                float(box[0]),
                float(box[1]),
                float(box[2]),
                float(box[3]),
            ],
        })
    return json.dumps(out, indent=2)
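
# Example of the JSON this produces (shape only; values are illustrative):
# [
#   {
#     "class_id": 0,
#     "class_name": "button",
#     "score": 0.91,
#     "box_xyxy": [12.0, 34.0, 120.0, 64.0]
#   }
# ]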

def detect_ui_elements(
    image: Image.Image,
    confidence_threshold: float,
    line_thickness: int
) -> Tuple[Image.Image, str, str]:
    """
    Detect UI elements in the uploaded image.

    Args:
        image: Input PIL Image
        confidence_threshold: Minimum confidence score for detections
        line_thickness: Thickness of bounding box lines

    Returns:
        Annotated image, a Markdown summary, and the raw detections as a JSON string
    """
    try:
        if image is None:
            return None, "Please upload an image first.", "[]"

        # Load model (cached after the first call)
        model = load_model()

        # Normalize to 3-channel RGB (PNG uploads may be RGBA), then to numpy
        img_array = np.array(image.convert("RGB"))

        # Convert RGB to BGR for OpenCV drawing
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)

        # Run detection (returns a supervision Detections object)
        detections = model.predict(img_array, threshold=confidence_threshold)

        # Extract detection data
        filtered_boxes = detections.xyxy          # Bounding boxes in xyxy format
        filtered_scores = detections.confidence   # Confidence scores
        filtered_classes = detections.class_id    # Class IDs

        # Draw detections
        annotated_img = draw_detections(
            img_bgr,
            filtered_boxes.tolist(),
            filtered_scores.tolist(),
            filtered_classes.tolist(),
            thickness=line_thickness
        )

        # Convert back to RGB for display
        annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
        annotated_pil = Image.fromarray(annotated_img_rgb)

        # Create summary text and raw JSON output
        summary_text = f"**Total detections:** {len(filtered_boxes)}"
        raw_json = detections_to_raw_json(detections)

        return annotated_pil, summary_text, raw_json
    except Exception as e:
        import traceback
        error_msg = f"**Error during detection:**\n\n```\n{str(e)}\n\n{traceback.format_exc()}\n```"
        print(error_msg)  # Also print to the Space logs
        return None, error_msg, "[]"

# Gradio interface
with gr.Blocks(title="UI-DETR-1 UI Element Detector", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # UI-DETR-1 UI Element Detector
    Upload a screenshot or UI mockup to automatically detect elements.
    """)
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="Upload Screenshot",
                height=400,
                sources=["upload"]
            )
            with gr.Accordion("Detection Settings", open=True):
                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.35,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Higher values = fewer but more confident detections"
                )
                thickness_slider = gr.Slider(
                    minimum=1,
                    maximum=6,
                    value=2,
                    step=1,
                    label="Box Line Thickness"
                )
            detect_button = gr.Button("Detect Elements", variant="primary", size="lg")
        with gr.Column(scale=1):
            output_image = gr.Image(
                type="pil",
                label="Detected Elements",
                height=400
            )
            summary_output = gr.Markdown(label="Detection Summary")
            raw_output = gr.Code(
                label="Raw Detections",
                language="json"
            )

    # Connect button
    detect_button.click(
        fn=detect_ui_elements,
        inputs=[input_image, confidence_slider, thickness_slider],
        outputs=[output_image, summary_output, raw_output]
    )

# Launch
if __name__ == "__main__":
    demo.queue().launch(share=False)
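
# To run locally (a sketch; versions are unpinned, and the fine-tuned
# "model.pth" checkpoint must first be downloaded from the Space's files):
#   pip install gradio torch opencv-python numpy pillow rfdetr
#   python app.py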