"""
Detection Service - Core Business Logic

This module contains the main DetectionService class that handles UI element
detection.

ARCHITECTURE:
-------------
This service uses a multi-model pipeline:

1. RF-DETR (Detection Transformer)
   - Detects generic "UI elements" as a SINGLE CLASS
   - Provides bounding boxes and confidence scores
   - Does NOT distinguish between button, input, text, etc.

2. CLIP (OpenAI) - OPTIONAL multi-class classification
   - Takes RF-DETR detections and classifies them into 6 types:
     * button, input, text, image, list_item, navigation
   - Only runs if enable_clip=True

3. EasyOCR
   - Extracts text content from detected regions
   - Runs global OCR merge to catch text outside detection boxes

4. BLIP (Salesforce) - OPTIONAL visual description generation
   - Describes icons and images when text is not present
   - Only runs if enable_blip=True

Usage:
    from detection.service import DetectionService

    service = DetectionService()
    results = service.analyze(image_path)
"""

import os

# Enable PyTorch's CPU fallback for operations not implemented on the Apple
# MPS backend; set before torch is imported.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import torch
import cv2
import numpy as np
from PIL import Image
from typing import Union, List, Dict, Tuple, Optional
from pathlib import Path
from rfdetr.detr import RFDETRMedium
import easyocr
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel

from detection.image_utils import load_image
from detection.image_preprocessing import preprocess_screenshot, PRESETS
from detection.rfdetr_preprocessing import preprocess_for_rfdetr, RFDETR_PRESETS


class DetectionService:
    """
    Detection Service for UI Element Detection

    Provides a complete pipeline for detecting and analyzing UI elements in
    screenshots. Uses RF-DETR for detection (single class), CLIP for
    classification (6 classes), OCR for text extraction, and BLIP for visual
    descriptions.

    Heavy models other than RF-DETR (OCR, CLIP, BLIP) are loaded lazily on
    first use.
    """

    # UI Element classes - Optimized for Mobile Apps
    # NOTE: These are NOT detected by RF-DETR (single class only)
    # CLIP classifies RF-DETR detections into these 6 types
    CLASSES = [
        'button',      # Buttons, FAB, chips, switches
        'input',       # Text fields, search bars
        'text',        # Labels, titles, paragraphs, descriptions
        'image',       # Images, icons, avatars, illustrations
        'list_item',   # List items, cards, tiles
        'navigation'   # Bottom nav, tabs, app bars, menus
    ]

    # CLIP text prompts - optimized for mobile UI. Index-aligned with CLASSES:
    # the prompt at position i describes CLASSES[i]. Kept as a class constant
    # so they are not rebuilt for every cropped detection.
    CLIP_PROMPTS = (
        "a mobile app button or interactive element",
        "a text input field or search bar in a mobile app",
        "text label, heading, or paragraph in a mobile app",
        "an image, icon, or avatar in a mobile app",
        "a list item, card, or tile in a mobile app",
        "a navigation bar, tab, or menu in a mobile app",
    )

    # Default box color (BGR format for OpenCV)
    BOX_COLOR = (0, 255, 0)  # Green

    # Global OCR merge: minimum IoU between an OCR box and a detection for the
    # OCR text to be attached to that detection.
    _GLOBAL_OCR_MIN_IOU = 0.1
    # Global OCR merge: a non-overlapping OCR box is attached to its nearest
    # detection when the squared center distance is <= this ratio times the
    # detection's squared diagonal, i.e. within (0.3 * diagonal).
    _GLOBAL_OCR_MAX_DIST2_RATIO = 0.09

    def __init__(self, model_path: str = "model.pth", enable_ocr: bool = True,
                 enable_blip: bool = True, enable_clip: bool = True):
        """
        Initialize the Detection Service

        Args:
            model_path: Path to the RF-DETR model weights
            enable_ocr: Whether to enable OCR for text extraction
            enable_blip: Whether to enable BLIP for icon description
            enable_clip: Whether to enable CLIP for UI element classification
        """
        self.model_path = model_path
        self.enable_ocr = enable_ocr
        self.enable_blip = enable_blip
        self.enable_clip = enable_clip

        self.model = None
        self.ocr_reader = None
        self.blip_processor = None
        self.blip_model = None
        self.clip_processor = None
        self.clip_model = None

        # Load the detection model immediately (the other models are lazy)
        self._load_detection_model()

    def _load_detection_model(self):
        """
        Load RF-DETR model (single-class UI element detector).

        The input resolution is taken from the RFDETR_RESOLUTION environment
        variable when it parses as an int. If the variable is set but invalid,
        no resolution is passed (the model's own default is used). If it is
        unset, 1600 is used.
        """
        if self.model is None:
            print("Loading RF-DETR model...")
            kwargs = {"pretrain_weights": self.model_path}
            custom_resolution = os.getenv("RFDETR_RESOLUTION")
            if custom_resolution:
                try:
                    kwargs["resolution"] = int(custom_resolution)
                    print(f"Using custom RF-DETR resolution: {kwargs['resolution']}")
                except ValueError:
                    print(f"Warning: invalid RFDETR_RESOLUTION '{custom_resolution}'. Falling back to model default.")
            else:
                kwargs["resolution"] = 1600  # Default tuned for CU-1 deployment

            self.model = RFDETRMedium(**kwargs)
            print("RF-DETR model loaded successfully!")

    def _load_ocr(self):
        """Load EasyOCR reader (English + French) for text extraction, if enabled."""
        if self.enable_ocr and self.ocr_reader is None:
            print("Loading OCR reader...")
            self.ocr_reader = easyocr.Reader(['en', 'fr'], gpu=torch.cuda.is_available())
            print("OCR reader loaded successfully!")

    def _load_blip(self):
        """Load BLIP model for image captioning, if enabled (moved to CUDA when available)."""
        if self.enable_blip and (self.blip_processor is None or self.blip_model is None):
            print("Loading BLIP model for icon description...")
            self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            # Use safetensors format to avoid torch.load vulnerability (CVE-2025-32434)
            self.blip_model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-base",
                use_safetensors=True
            )
            if torch.cuda.is_available():
                self.blip_model = self.blip_model.to("cuda")
            print("BLIP model loaded successfully!")

    def _load_clip(self):
        """Load CLIP model for UI element classification, if enabled (moved to CUDA when available)."""
        if self.enable_clip and (self.clip_processor is None or self.clip_model is None):
            print("Loading CLIP model for UI element classification...")
            self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            # Use safetensors format to avoid torch.load vulnerability (CVE-2025-32434)
            self.clip_model = CLIPModel.from_pretrained(
                "openai/clip-vit-base-patch32",
                use_safetensors=True
            )
            if torch.cuda.is_available():
                self.clip_model = self.clip_model.to("cuda")
            print("CLIP model loaded successfully!")

    @torch.inference_mode()
    def _classify_with_clip(self, cropped_img: np.ndarray) -> int:
        """
        Classify UI element using CLIP zero-shot matching against CLIP_PROMPTS.

        Args:
            cropped_img: Cropped numpy array of the UI element

        Returns:
            Predicted class_id (0-5 corresponding to CLASSES). Returns 0 for an
            empty crop, when CLIP is disabled, or when classification raises.
            NOTE(review): 0 is also 'button', so these fallbacks are
            indistinguishable from a real 'button' prediction.
        """
        if cropped_img.size == 0:
            return 0  # Default to first class

        if not self.enable_clip:
            return 0  # No classification, return default

        self._load_clip()

        try:
            pil_img = Image.fromarray(cropped_img)

            inputs = self.clip_processor(
                text=list(self.CLIP_PROMPTS),
                images=pil_img,
                return_tensors="pt",
                padding=True
            )

            if torch.cuda.is_available():
                inputs = {k: v.to("cuda") for k, v in inputs.items()}

            # One image vs. N prompts -> logits_per_image has one row; pick the
            # prompt with highest probability.
            outputs = self.clip_model(**inputs)
            probs = outputs.logits_per_image.softmax(dim=1)
            return probs.argmax().item()

        except Exception as clip_error:
            print(f"CLIP classification error: {clip_error}")
            return 0  # Fallback to default class

    def _extract_text(self, cropped_img: np.ndarray) -> str:
        """Extract plain text from a cropped region using OCR (no BLIP).

        Returns "" when OCR is disabled, the crop is empty, or OCR raises.
        """
        if not self.enable_ocr or cropped_img.size == 0:
            return ""
        self._load_ocr()
        try:
            ocr_results = self.ocr_reader.readtext(cropped_img, detail=0)
            return " ".join(ocr_results).strip()
        except Exception as ocr_error:
            print(f"OCR error: {ocr_error}")
            return ""

    def _describe_with_blip(self, cropped_img: np.ndarray) -> str:
        """Generate a visual description using BLIP for a cropped region.

        Returns "" when BLIP is disabled, the crop is empty, or BLIP raises.
        """
        if not self.enable_blip or cropped_img.size == 0:
            return ""
        self._load_blip()
        try:
            pil_img = Image.fromarray(cropped_img)
            inputs = self.blip_processor(pil_img, return_tensors="pt")
            if torch.cuda.is_available():
                inputs = {k: v.to("cuda") for k, v in inputs.items()}
            out = self.blip_model.generate(**inputs, max_length=50)
            return self.blip_processor.decode(out[0], skip_special_tokens=True)
        except Exception as blip_error:
            print(f"BLIP error: {blip_error}")
            return ""

    @staticmethod
    def _iou(box_a: Tuple[int, int, int, int], box_b: Tuple[int, int, int, int]) -> float:
        """Calculate Intersection over Union between two (x1, y1, x2, y2) boxes.

        Returns 0.0 when the boxes do not intersect or the union is non-positive.
        """
        xA = max(box_a[0], box_b[0])
        yA = max(box_a[1], box_b[1])
        xB = min(box_a[2], box_b[2])
        yB = min(box_a[3], box_b[3])

        inter_w = max(0, xB - xA)
        inter_h = max(0, yB - yA)
        inter_area = inter_w * inter_h
        if inter_area == 0:
            return 0.0

        box_a_area = max(0, (box_a[2] - box_a[0])) * max(0, (box_a[3] - box_a[1]))
        box_b_area = max(0, (box_b[2] - box_b[0])) * max(0, (box_b[3] - box_b[1]))
        union = box_a_area + box_b_area - inter_area
        if union <= 0:
            return 0.0
        return inter_area / union

    @staticmethod
    def _box_center(box: Tuple[int, int, int, int]) -> Tuple[float, float]:
        """Calculate the center point of an (x1, y1, x2, y2) bounding box."""
        x1, y1, x2, y2 = box
        return (x1 + x2) / 2.0, (y1 + y2) / 2.0

    @staticmethod
    def _append_text(detection: Dict, text: str) -> None:
        """Append `text` to detection["text"] unless it already occurs there (substring check)."""
        existing = detection["text"].strip()
        if text not in existing:
            detection["text"] = (existing + (" " if existing else "") + text).strip()

    @staticmethod
    def _apply_preprocessing(
        img_array: np.ndarray,
        preprocess_mode: str,
        preprocess_preset: str
    ) -> Tuple[np.ndarray, bool, Dict]:
        """
        Run the requested preprocessing on an image.

        Unknown modes fall back to RF-DETR preprocessing with the 'standard'
        preset. Any exception is caught: the original image is returned
        unchanged together with {"error": <message>}.

        Returns:
            (image, preprocessed_flag, preprocessing_info)
        """
        try:
            if preprocess_mode == "rfdetr":
                # RF-DETR optimized preprocessing (preserves ImageNet normalization)
                processed = preprocess_for_rfdetr(img_array, preset=preprocess_preset)
                return processed, True, {
                    "mode": "rfdetr",
                    "preset": preprocess_preset,
                    "description": "RF-DETR optimized (preserves ImageNet normalization)"
                }
            if preprocess_mode == "generic":
                # Generic preprocessing (for CLIP/OCR optimization)
                processed = preprocess_screenshot(img_array, preset=preprocess_preset)
                return processed, True, {
                    "mode": "generic",
                    "preset": preprocess_preset,
                    "description": "Generic preprocessing (CLIP/OCR optimized)"
                }
            print(f"Warning: Unknown preprocess_mode '{preprocess_mode}'. Using 'rfdetr'.")
            processed = preprocess_for_rfdetr(img_array, preset="standard")
            return processed, True, {
                "mode": "rfdetr",
                "preset": "standard",
                "description": "RF-DETR optimized (fallback)"
            }
        except Exception as e:
            print(f"Warning: Preprocessing failed: {e}. Continuing with original image.")
            return img_array, False, {"error": str(e)}

    def _merge_global_ocr(self, img_array: np.ndarray, detections: List[Dict]) -> None:
        """
        Run OCR on the whole image and fold each recognized text into `detections` (in place).

        For each OCR result, in order:
          1. If its box overlaps some detection with IoU >= _GLOBAL_OCR_MIN_IOU,
             the text is appended to the best-overlapping detection.
          2. Otherwise, if the nearest detection (by center distance) is within
             sqrt(_GLOBAL_OCR_MAX_DIST2_RATIO) of that detection's diagonal,
             the text is appended there.
          3. Otherwise a new OCR-only detection is created. Later OCR results
             may attach to it.

        Exceptions are not caught here; the caller decides how to handle them.
        """
        self._load_ocr()
        # detail=1 returns [ [ (x,y)...4 points ], text, conf ]
        global_ocr = self.ocr_reader.readtext(img_array, detail=1)

        # Precompute detection boxes as int tuples (kept in sync with `detections`)
        det_boxes: List[Tuple[int, int, int, int]] = [
            (int(d["box"]["x1"]), int(d["box"]["y1"]), int(d["box"]["x2"]), int(d["box"]["y2"]))
            for d in detections
        ]

        for entry in global_ocr:
            if not isinstance(entry, (list, tuple)) or len(entry) < 2:
                continue
            quad = entry[0]
            text = entry[1] if isinstance(entry[1], str) else ""
            if not text:
                continue

            # Convert quadrilateral to axis-aligned bounding box
            xs = [p[0] for p in quad]
            ys = [p[1] for p in quad]
            obox = (int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys)))

            # 1) Overlap with an existing detection -> attach to best overlap
            overlaps = [self._iou(obox, db) for db in det_boxes]
            if overlaps and max(overlaps) >= self._GLOBAL_OCR_MIN_IOU:
                best_overlap_idx = int(np.argmax(np.array(overlaps)))
                self._append_text(detections[best_overlap_idx], text)
                continue

            # 2) No sufficient overlap -> nearest detection by center distance
            ox, oy = self._box_center(obox)
            best_idx = -1
            best_dist = float("inf")
            for idx, dbox in enumerate(det_boxes):
                cx, cy = self._box_center(dbox)
                dx = cx - ox
                dy = cy - oy
                dist2 = dx * dx + dy * dy
                if dist2 < best_dist:
                    best_dist = dist2
                    best_idx = idx

            if best_idx >= 0:
                bx1, by1, bx2, by2 = det_boxes[best_idx]
                bw = max(1, bx2 - bx1)
                bh = max(1, by2 - by1)
                diag2 = bw * bw + bh * bh
                if best_dist <= self._GLOBAL_OCR_MAX_DIST2_RATIO * diag2:
                    self._append_text(detections[best_idx], text)
                    continue

            # 3) Not overlapping or near any detection -> new OCR-only detection
            detections.append({
                "box": {
                    "x1": float(obox[0]),
                    "y1": float(obox[1]),
                    "x2": float(obox[2]),
                    "y2": float(obox[3]),
                },
                "confidence": float(entry[2]) if len(entry) > 2 and entry[2] is not None else 1.0,
                "class_id": None,
                "class_name": "",
                "text": text.strip(),
                "description": "",
            })
            det_boxes.append(obox)

    @torch.inference_mode()
    def analyze(
        self,
        image: Union[str, Path, np.ndarray, Image.Image],
        confidence_threshold: float = 0.35,
        extract_text: bool = True,
        use_clip: bool = True,
        use_blip: bool = False,
        merge_global_ocr: bool = True,
        blip_scope: str = "icons",
        preprocess: bool = False,
        preprocess_preset: str = "standard",
        preprocess_mode: str = "rfdetr"
    ) -> Dict:
        """
        Run a single-pass analysis: detection, optional CLIP classification, OCR,
        optional BLIP, and optional global OCR merge into nearest detection.

        PIPELINE:
        0. Optional preprocessing (normalize colors, contrast, denoise)
        1. RF-DETR detects UI elements (single class - just bounding boxes)
        2. CLIP classifies each detection into 6 types (if use_clip=True)
        3. OCR extracts text from each detection (if extract_text=True)
        4. BLIP generates descriptions for icons (if use_blip=True)
        5. Global OCR merge attaches stray text to an overlapping or nearby
           detection, or adds it as a new OCR-only detection
           (if merge_global_ocr=True)

        Args:
            image: Input image (path, PIL Image, or numpy array)
            confidence_threshold: Minimum confidence for RF-DETR detections
            extract_text: Whether to run OCR on detections
            use_clip: Whether to classify detections with CLIP
            use_blip: Whether to generate BLIP descriptions
            merge_global_ocr: Whether to run global OCR and merge results
            blip_scope: "icons" (only image/button) or "all" (all elements).
                With "icons" and CLIP disabled, class_name is "" so BLIP
                does not run.
            preprocess: Enable image preprocessing (recommended for cross-device consistency)
            preprocess_mode: Preprocessing mode - 'rfdetr' (optimized for RF-DETR)
                or 'generic' (for CLIP/OCR)
            preprocess_preset: Preprocessing preset - depends on mode:
                - rfdetr mode: 'gentle', 'standard', 'aggressive_denoise', 'color_only'
                - generic mode: 'standard', 'aggressive', 'minimal', 'ocr_optimized'

        Returns:
            Dict with keys:
                - detections: List of {box, confidence, class_id, class_name, text, description}
                - image_size: {width, height}
                - preprocessed: Whether preprocessing was applied
                - preprocessing_info: Preprocessing details when applied,
                  {"error": ...} when preprocessing was requested but failed,
                  otherwise None
        """
        img_array = load_image(image)

        # Optional preprocessing for cross-device consistency
        preprocessed = False
        preprocessing_info: Dict = {}
        if preprocess:
            img_array, preprocessed, preprocessing_info = self._apply_preprocessing(
                img_array, preprocess_mode, preprocess_preset
            )

        height, width = img_array.shape[:2]

        # RF-DETR Detection: Detects generic UI elements (SINGLE CLASS ONLY)
        det = self.model.predict(img_array, threshold=confidence_threshold)
        boxes = det.xyxy.tolist()
        scores = det.confidence.tolist()

        run_clip = use_clip and self.enable_clip
        run_ocr = extract_text and self.enable_ocr
        run_blip = use_blip and self.enable_blip

        detections: List[Dict] = []
        for box, score in zip(boxes, scores):
            x1, y1, x2, y2 = map(int, box)
            # Clamp the slice bounds at 0: boxes touching the image edge can
            # have slightly negative coordinates, and a negative index would
            # wrap around to the far side of the image (empty or wrong crop).
            cropped = img_array[max(0, y1):max(0, y2), max(0, x1):max(0, x2)]

            # CLIP Classification: Classify RF-DETR detection into one of 6 types
            if run_clip:
                predicted_class_id = self._classify_with_clip(cropped)
                if 0 <= predicted_class_id < len(self.CLASSES):
                    class_name = self.CLASSES[predicted_class_id]
                else:
                    class_name = "unknown"
            else:
                predicted_class_id = None
                class_name = ""

            # OCR text extraction per detection
            text = self._extract_text(cropped) if run_ocr else ""

            # BLIP description per detection (kept separate from text)
            description = ""
            if run_blip and (blip_scope == "all" or class_name in {"image", "button"}):
                description = self._describe_with_blip(cropped)

            detections.append({
                "box": {"x1": float(x1), "y1": float(y1), "x2": float(x2), "y2": float(y2)},
                "confidence": float(score),
                "class_id": predicted_class_id,
                "class_name": class_name,
                "text": text,
                "description": description,
            })

        # Optional global OCR merge: best-effort, never fails the analysis
        if merge_global_ocr and run_ocr:
            try:
                self._merge_global_ocr(img_array, detections)
            except Exception as e:
                print(f"Global OCR merge error: {e}")

        return {
            "detections": detections,
            "image_size": {"width": int(width), "height": int(height)},
            "preprocessed": preprocessed,
            # Empty dict (no preprocessing requested) -> None; keeps the error
            # report when preprocessing was requested but failed.
            "preprocessing_info": preprocessing_info or None
        }

    def _draw_detections(
        self,
        image: np.ndarray,
        boxes: List[List[float]],
        scores: List[float],
        classes: List[int],
        contents: Optional[List[str]] = None,
        thickness: int = 3,
        font_scale: float = 0.5
    ) -> np.ndarray:
        """
        Draw detection boxes and labels on a copy of `image`.

        Each box gets a confidence label above it and, when `contents[idx]` is
        non-empty, a content caption below it (truncated to 40 characters).
        `classes` is iterated alongside boxes/scores but not rendered.
        Colors are BGR (OpenCV convention).
        """
        img_with_boxes = image.copy()

        for idx, (box, score, cls_id) in enumerate(zip(boxes, scores, classes)):
            x1, y1, x2, y2 = map(int, box)

            # Draw rectangle
            cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), self.BOX_COLOR, thickness)

            # Prepare label with confidence score
            label = f"{score:.2f}"

            # Add content if available
            content = ""
            if contents and idx < len(contents) and contents[idx]:
                content = contents[idx]
                # Truncate long content for display
                if len(content) > 40:
                    content = content[:37] + "..."

            # Calculate label size and position
            (label_width, label_height), baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness=2
            )

            # Draw label background (kept on-screen when the box touches the top)
            label_y = max(y1 - 10, label_height + 10)
            cv2.rectangle(
                img_with_boxes,
                (x1, label_y - label_height - baseline - 5),
                (x1 + label_width + 5, label_y + baseline - 5),
                self.BOX_COLOR,
                -1
            )

            # Draw label text (confidence score)
            cv2.putText(
                img_with_boxes,
                label,
                (x1 + 2, label_y - baseline - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                font_scale,
                (255, 255, 255),
                thickness=2
            )

            # Draw content text below the box if available
            if content:
                content_font_scale = font_scale * 0.8
                (content_width, content_height), content_baseline = cv2.getTextSize(
                    content, cv2.FONT_HERSHEY_SIMPLEX, content_font_scale, thickness=1
                )

                # Position content below the bottom of the box, clamped to the image
                content_y = min(y2 + content_height + 15, img_with_boxes.shape[0] - 5)

                # Draw content background
                cv2.rectangle(
                    img_with_boxes,
                    (x1, content_y - content_height - content_baseline - 3),
                    (x1 + content_width + 5, content_y + content_baseline),
                    (0, 180, 0),  # Slightly darker green
                    -1
                )

                # Draw content text
                cv2.putText(
                    img_with_boxes,
                    content,
                    (x1 + 2, content_y - content_baseline - 3),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    content_font_scale,
                    (255, 255, 255),
                    thickness=1
                )

        return img_with_boxes

    @torch.inference_mode()
    def get_prediction_image(
        self,
        image: Union[str, Path, np.ndarray, Image.Image],
        confidence_threshold: float = 0.35,
        extract_content: bool = True,
        thickness: int = 3,
        font_scale: float = 0.5,
        return_format: str = "pil",
        analysis: Optional[Dict] = None
    ) -> Union[Image.Image, np.ndarray]:
        """
        Get annotated image with detection boxes drawn

        Args:
            image: Input image (path, PIL Image, or numpy array)
            confidence_threshold: Minimum confidence score for detections (0.0-1.0)
            extract_content: Whether to extract and display text content or icon descriptions
            thickness: Thickness of bounding box lines
            font_scale: Font scale for labels
            return_format: Return format - "pil" for PIL Image or "numpy" for
                numpy array (any value other than "pil", case-insensitive,
                yields numpy)
            analysis: Pre-computed analysis results (optional, for performance).
                When None, analyze() is run with this service's CLIP/BLIP
                settings and global OCR merge enabled.

        Returns:
            Annotated image as PIL Image or numpy array (RGB)
        """
        # load_image output is treated as RGB (converted to BGR below for OpenCV)
        img_array = load_image(image)

        if analysis is None:
            analysis = self.analyze(
                image,
                confidence_threshold=confidence_threshold,
                extract_text=extract_content,
                use_clip=self.enable_clip,
                use_blip=self.enable_blip,
                merge_global_ocr=True
            )

        boxes = []
        scores = []
        class_ids = []
        contents = []
        for det in analysis["detections"]:
            b = det["box"]
            boxes.append([b["x1"], b["y1"], b["x2"], b["y2"]])
            scores.append(det["confidence"])
            class_ids.append(det["class_id"] if det.get("class_id") is not None else 0)
            if extract_content:
                # Prefer OCR text; otherwise show the BLIP description as an icon caption
                text = det.get("text") or ""
                desc = det.get("description") or ""
                contents.append(text if text else (f"[Icon: {desc}]" if desc else ""))

        # Convert to BGR for OpenCV
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)

        # Draw detections
        annotated_img = self._draw_detections(
            img_bgr,
            boxes,
            scores,
            class_ids,
            contents if extract_content else None,
            thickness,
            font_scale
        )

        # Convert back to RGB
        annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)

        # Return in requested format
        if return_format.lower() == "pil":
            return Image.fromarray(annotated_img_rgb)
        else:
            return annotated_img_rgb