# CU1-X: detection/service.py
# Author: AI-DrivenTesting (commit 77da9e2, "init")
"""
Detection Service - Core Business Logic
This module contains the main DetectionService class that handles UI element detection.
ARCHITECTURE:
-------------
This service uses a multi-model pipeline:
1. RF-DETR (Detection Transformer)
- Detects generic "UI elements" as a SINGLE CLASS
- Provides bounding boxes and confidence scores
- Does NOT distinguish between button, input, text, etc.
2. CLIP (OpenAI)
- OPTIONAL multi-class classification
- Takes RF-DETR detections and classifies them into 6 types:
* button, input, text, image, list_item, navigation
- Only runs if enable_clip=True
3. EasyOCR
- Extracts text content from detected regions
- Runs global OCR merge to catch text outside detection boxes
4. BLIP (Salesforce)
- OPTIONAL visual description generation
- Describes icons and images when text is not present
- Only runs if enable_blip=True
Usage:
from detection.service import DetectionService
service = DetectionService()
results = service.analyze(image_path)
"""
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch
import cv2
import numpy as np
from PIL import Image
from typing import Union, List, Dict, Tuple, Optional
from pathlib import Path
from rfdetr.detr import RFDETRMedium
import easyocr
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
from detection.image_utils import load_image
from detection.image_preprocessing import preprocess_screenshot, PRESETS
from detection.rfdetr_preprocessing import preprocess_for_rfdetr, RFDETR_PRESETS
class DetectionService:
    """
    Detection Service for UI Element Detection

    Provides a complete pipeline for detecting and analyzing UI elements in screenshots.
    Uses RF-DETR for detection (single class), CLIP for classification (6 classes),
    OCR for text extraction, and BLIP for visual descriptions.

    The RF-DETR model is loaded in __init__; the EasyOCR reader, CLIP and BLIP are
    loaded lazily the first time they are needed.
    """

    # UI Element classes - Optimized for Mobile Apps
    # NOTE: These are NOT detected by RF-DETR (single class only)
    # CLIP classifies RF-DETR detections into these 6 types
    CLASSES = [
        'button',      # Buttons, FAB, chips, switches
        'input',       # Text fields, search bars
        'text',        # Labels, titles, paragraphs, descriptions
        'image',       # Images, icons, avatars, illustrations
        'list_item',   # List items, cards, tiles
        'navigation'   # Bottom nav, tabs, app bars, menus
    ]

    # Zero-shot CLIP text prompts, one per entry of CLASSES and in the SAME order:
    # the index of the best-matching prompt is used directly as an index into CLASSES.
    CLIP_PROMPTS = [
        "a mobile app button or interactive element",
        "a text input field or search bar in a mobile app",
        "text label, heading, or paragraph in a mobile app",
        "an image, icon, or avatar in a mobile app",
        "a list item, card, or tile in a mobile app",
        "a navigation bar, tab, or menu in a mobile app"
    ]

    # Default box color (BGR format for OpenCV)
    BOX_COLOR = (0, 255, 0)  # Green

    # Background of the content caption drawn under each box (BGR)
    CONTENT_BG_COLOR = (0, 180, 0)  # Slightly darker green

    def __init__(self, model_path: str = "model.pth", enable_ocr: bool = True, enable_blip: bool = True, enable_clip: bool = True):
        """
        Initialize the Detection Service and load the RF-DETR model.

        Args:
            model_path: Path to the RF-DETR model weights
            enable_ocr: Whether to enable OCR for text extraction
            enable_blip: Whether to enable BLIP for icon description
            enable_clip: Whether to enable CLIP for UI element classification
        """
        self.model_path = model_path
        self.enable_ocr = enable_ocr
        self.enable_blip = enable_blip
        self.enable_clip = enable_clip
        self.model = None
        self.ocr_reader = None
        self.blip_processor = None
        self.blip_model = None
        self.clip_processor = None
        self.clip_model = None
        # Load the detection model immediately; the other models load lazily.
        self._load_detection_model()

    def _load_detection_model(self):
        """Load the RF-DETR model (single-class UI element detector) if not loaded yet.

        The resolution comes from the RFDETR_RESOLUTION environment variable when it
        parses as an int. An unset or empty variable selects 1600; a non-integer value
        omits ``resolution`` so the RFDETRMedium default applies.
        """
        if self.model is not None:
            return
        print("Loading RF-DETR model...")
        kwargs = {"pretrain_weights": self.model_path}
        custom_resolution = os.getenv("RFDETR_RESOLUTION")
        if custom_resolution:
            try:
                kwargs["resolution"] = int(custom_resolution)
                print(f"Using custom RF-DETR resolution: {kwargs['resolution']}")
            except ValueError:
                print(f"Warning: invalid RFDETR_RESOLUTION '{custom_resolution}'. Falling back to model default.")
        else:
            kwargs["resolution"] = 1600  # Default tuned for CU-1 deployment
        self.model = RFDETRMedium(**kwargs)
        print("RF-DETR model loaded successfully!")

    def _load_ocr(self):
        """Load the EasyOCR reader (English + French) once, using the GPU when CUDA is available."""
        if self.enable_ocr and self.ocr_reader is None:
            print("Loading OCR reader...")
            self.ocr_reader = easyocr.Reader(['en', 'fr'], gpu=torch.cuda.is_available())
            print("OCR reader loaded successfully!")

    def _load_blip(self):
        """Load the BLIP image-captioning processor and model once (moved to CUDA when available)."""
        if self.enable_blip and (self.blip_processor is None or self.blip_model is None):
            print("Loading BLIP model for icon description...")
            self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            # Use safetensors format to avoid torch.load vulnerability (CVE-2025-32434)
            self.blip_model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-base",
                use_safetensors=True
            )
            if torch.cuda.is_available():
                self.blip_model = self.blip_model.to("cuda")
            print("BLIP model loaded successfully!")

    def _load_clip(self):
        """Load the CLIP processor and model once (moved to CUDA when available)."""
        if self.enable_clip and (self.clip_processor is None or self.clip_model is None):
            print("Loading CLIP model for UI element classification...")
            self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            # Use safetensors format to avoid torch.load vulnerability (CVE-2025-32434)
            self.clip_model = CLIPModel.from_pretrained(
                "openai/clip-vit-base-patch32",
                use_safetensors=True
            )
            if torch.cuda.is_available():
                self.clip_model = self.clip_model.to("cuda")
            print("CLIP model loaded successfully!")

    def _classify_with_clip(self, cropped_img: np.ndarray) -> int:
        """
        Classify a cropped UI element into one of CLASSES with zero-shot CLIP.

        Args:
            cropped_img: Cropped numpy array of the UI element, handed to
                PIL ``Image.fromarray`` (presumably RGB as produced by load_image).

        Returns:
            Predicted class_id (0-5, an index into CLASSES). 0 ('button') is also
            returned for empty crops, when CLIP is disabled, and when classification
            raises, so a 0 is not always a genuine prediction.
        """
        if cropped_img.size == 0:
            return 0  # Default to first class
        if not self.enable_clip:
            return 0  # No classification, return default
        self._load_clip()
        try:
            pil_img = Image.fromarray(cropped_img)
            inputs = self.clip_processor(
                text=self.CLIP_PROMPTS,
                images=pil_img,
                return_tensors="pt",
                padding=True
            )
            # Send inputs wherever the model actually lives.
            device = self.clip_model.device
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = self.clip_model(**inputs)
            # logits_per_image is (num_images=1, num_prompts); softmax is monotonic,
            # so the argmax of the raw logits is already the most probable class.
            return int(outputs.logits_per_image.argmax(dim=-1)[0].item())
        except Exception as clip_error:
            print(f"CLIP classification error: {clip_error}")
            return 0  # Fallback to default class

    def _extract_text(self, cropped_img: np.ndarray) -> str:
        """Extract plain text from a cropped region using OCR (no BLIP).

        Returns the recognized fragments joined by single spaces, or "" when OCR is
        disabled, the crop is empty, or the reader raises (the error is printed).
        """
        if not self.enable_ocr or cropped_img.size == 0:
            return ""
        self._load_ocr()
        try:
            ocr_results = self.ocr_reader.readtext(cropped_img, detail=0)
            return " ".join(ocr_results).strip()
        except Exception as ocr_error:
            print(f"OCR error: {ocr_error}")
            return ""

    def _describe_with_blip(self, cropped_img: np.ndarray) -> str:
        """Generate a visual caption (max_length=50) for a cropped region using BLIP.

        Returns "" when BLIP is disabled, the crop is empty, or generation raises
        (the error is printed).
        """
        if not self.enable_blip or cropped_img.size == 0:
            return ""
        self._load_blip()
        try:
            pil_img = Image.fromarray(cropped_img)
            inputs = self.blip_processor(pil_img, return_tensors="pt")
            device = self.blip_model.device
            inputs = {k: v.to(device) for k, v in inputs.items()}
            out = self.blip_model.generate(**inputs, max_length=50)
            return self.blip_processor.decode(out[0], skip_special_tokens=True)
        except Exception as blip_error:
            print(f"BLIP error: {blip_error}")
            return ""

    @staticmethod
    def _iou(box_a: Tuple[int, int, int, int], box_b: Tuple[int, int, int, int]) -> float:
        """Return the Intersection over Union of two (x1, y1, x2, y2) boxes (0.0 if disjoint)."""
        xA = max(box_a[0], box_b[0])
        yA = max(box_a[1], box_b[1])
        xB = min(box_a[2], box_b[2])
        yB = min(box_a[3], box_b[3])
        inter_area = max(0, xB - xA) * max(0, yB - yA)
        if inter_area == 0:
            return 0.0
        box_a_area = max(0, (box_a[2] - box_a[0])) * max(0, (box_a[3] - box_a[1]))
        box_b_area = max(0, (box_b[2] - box_b[0])) * max(0, (box_b[3] - box_b[1]))
        union = box_a_area + box_b_area - inter_area
        if union <= 0:
            return 0.0
        return inter_area / union

    @staticmethod
    def _box_center(box: Tuple[int, int, int, int]) -> Tuple[float, float]:
        """Return the (cx, cy) center point of an (x1, y1, x2, y2) box."""
        x1, y1, x2, y2 = box
        return (x1 + x2) / 2.0, (y1 + y2) / 2.0

    @staticmethod
    def _crop_region(image: np.ndarray, x1: int, y1: int, x2: int, y2: int) -> np.ndarray:
        """Return ``image[y1:y2, x1:x2]`` after clamping the box to the image bounds.

        Clamping is required for negative coordinates: a detector box poking past the
        top/left edge would otherwise be read by numpy slicing as an offset from the
        END of the axis, yielding an empty (or wrong) crop.
        """
        height, width = image.shape[:2]
        x1 = min(max(x1, 0), width)
        x2 = min(max(x2, 0), width)
        y1 = min(max(y1, 0), height)
        y2 = min(max(y2, 0), height)
        return image[y1:y2, x1:x2]

    @staticmethod
    def _append_text(detection: Dict, text: str) -> None:
        """Append ``text`` to ``detection["text"]`` (space-separated) unless already present.

        The duplicate check is a plain substring test, so a short token that happens
        to occur inside the existing text (e.g. "Go" inside "Good") is not appended.
        """
        existing = detection["text"].strip()
        if text not in existing:
            detection["text"] = (existing + (" " if existing else "") + text).strip()

    def _apply_preprocessing(
        self, img_array: np.ndarray, preprocess_mode: str, preprocess_preset: str
    ) -> Tuple[np.ndarray, bool, Dict]:
        """Apply the requested preprocessing, falling back to the input image on failure.

        Returns:
            (image, preprocessed, info): ``info`` records mode/preset/description on
            success, or ``{"error": <message>}`` when preprocessing raised.
        """
        try:
            if preprocess_mode == "rfdetr":
                # RF-DETR optimized preprocessing (preserves ImageNet normalization)
                return preprocess_for_rfdetr(img_array, preset=preprocess_preset), True, {
                    "mode": "rfdetr",
                    "preset": preprocess_preset,
                    "description": "RF-DETR optimized (preserves ImageNet normalization)"
                }
            if preprocess_mode == "generic":
                # Generic preprocessing (for CLIP/OCR optimization)
                return preprocess_screenshot(img_array, preset=preprocess_preset), True, {
                    "mode": "generic",
                    "preset": preprocess_preset,
                    "description": "Generic preprocessing (CLIP/OCR optimized)"
                }
            print(f"Warning: Unknown preprocess_mode '{preprocess_mode}'. Using 'rfdetr'.")
            return preprocess_for_rfdetr(img_array, preset="standard"), True, {
                "mode": "rfdetr",
                "preset": "standard",
                "description": "RF-DETR optimized (fallback)"
            }
        except Exception as e:
            print(f"Warning: Preprocessing failed: {e}. Continuing with original image.")
            return img_array, False, {"error": str(e)}

    def _merge_global_ocr(self, img_array: np.ndarray, detections: List[Dict]) -> None:
        """Run OCR on the whole image and fold every text line into ``detections`` in place.

        Each OCR line (whitespace-stripped, empty lines skipped) is:
          1. appended to the detection it overlaps most, if that IoU is >= 0.1; else
          2. appended to the detection with the nearest center, if the center distance
             is within 0.3 of that detection's diagonal; else
          3. added as a new OCR-only detection (class_id None), which later OCR lines
             may also merge into.
        Any exception is printed and stops the merge; changes made so far are kept.
        """
        try:
            self._load_ocr()
            # detail=1 returns [ [ (x,y)...4 points ], text, conf ]
            global_ocr = self.ocr_reader.readtext(img_array, detail=1)
            det_boxes: List[Tuple[int, int, int, int]] = [
                (int(d["box"]["x1"]), int(d["box"]["y1"]), int(d["box"]["x2"]), int(d["box"]["y2"]))
                for d in detections
            ]
            for entry in global_ocr:
                if not isinstance(entry, (list, tuple)) or len(entry) < 2:
                    continue
                quad = entry[0]
                # Strip up front so the dedup check in _append_text is not fooled
                # by surrounding whitespace (e.g. "Login " vs "Login").
                text = entry[1].strip() if isinstance(entry[1], str) else ""
                if not text:
                    continue
                # Convert quadrilateral to axis-aligned bounding box
                xs = [p[0] for p in quad]
                ys = [p[1] for p in quad]
                obox = (int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys)))

                # 1) Sufficient overlap → attach to best-overlap detection
                overlaps = [self._iou(obox, db) for db in det_boxes]
                if overlaps and max(overlaps) >= 0.1:
                    self._append_text(detections[int(np.argmax(overlaps))], text)
                    continue

                # 2) Nearest detection by (squared) center distance
                if det_boxes:
                    ox, oy = self._box_center(obox)
                    sq_dists = []
                    for dbox in det_boxes:
                        cx, cy = self._box_center(dbox)
                        dx = cx - ox
                        dy = cy - oy
                        sq_dists.append(dx * dx + dy * dy)
                    best_idx = min(range(len(sq_dists)), key=sq_dists.__getitem__)
                    bx1, by1, bx2, by2 = det_boxes[best_idx]
                    bw = max(1, bx2 - bx1)
                    bh = max(1, by2 - by1)
                    # Conservative threshold: within 0.3 of the diagonal, i.e. (0.3 * diag)^2
                    if sq_dists[best_idx] <= 0.09 * (bw * bw + bh * bh):
                        self._append_text(detections[best_idx], text)
                        continue

                # 3) Not overlapping or near any detection → new OCR-only detection
                detections.append({
                    "box": {
                        "x1": float(obox[0]),
                        "y1": float(obox[1]),
                        "x2": float(obox[2]),
                        "y2": float(obox[3]),
                    },
                    "confidence": float(entry[2]) if len(entry) > 2 and entry[2] is not None else 1.0,
                    "class_id": None,
                    "class_name": "",
                    "text": text,
                    "description": "",
                })
                det_boxes.append(obox)
        except Exception as e:
            print(f"Global OCR merge error: {e}")

    @torch.inference_mode()
    def analyze(
        self,
        image: Union[str, Path, np.ndarray, Image.Image],
        confidence_threshold: float = 0.35,
        extract_text: bool = True,
        use_clip: bool = True,
        use_blip: bool = False,
        merge_global_ocr: bool = True,
        blip_scope: str = "icons",
        preprocess: bool = False,
        preprocess_preset: str = "standard",
        preprocess_mode: str = "rfdetr"
    ) -> Dict:
        """
        Run a single-pass analysis: detection, optional CLIP classification, OCR, optional BLIP,
        and optional global OCR merge into nearest detection.

        PIPELINE:
            0. Optional preprocessing (normalize colors, contrast, denoise)
            1. RF-DETR detects UI elements (single class - just bounding boxes)
            2. CLIP classifies each detection into 6 types (if use_clip=True)
            3. OCR extracts text from each detection (if extract_text=True)
            4. BLIP generates descriptions for icons (if use_blip=True)
            5. Global OCR merge attaches stray text to overlapping/nearby detections,
               or adds it as new text-only detections (if merge_global_ocr=True)

        Args:
            image: Input image (path, PIL Image, or numpy array)
            confidence_threshold: Minimum confidence for RF-DETR detections
            extract_text: Whether to run OCR on detections
            use_clip: Whether to classify detections with CLIP (also needs enable_clip)
            use_blip: Whether to generate BLIP descriptions (also needs enable_blip)
            merge_global_ocr: Whether to run global OCR and merge results
            blip_scope: "icons" (only image/button) or "all" (all elements). "icons"
                relies on CLIP class names, so without CLIP BLIP only runs with "all".
            preprocess: Enable image preprocessing (recommended for cross-device consistency)
            preprocess_preset: Preprocessing preset - depends on mode:
                - rfdetr mode: 'gentle', 'standard', 'aggressive_denoise', 'color_only'
                - generic mode: 'standard', 'aggressive', 'minimal', 'ocr_optimized'
            preprocess_mode: Preprocessing mode - 'rfdetr' (optimized for RF-DETR) or
                'generic' (for CLIP/OCR); unknown modes fall back to rfdetr/'standard'.

        Returns:
            Dict with keys:
            - detections: List of {box, confidence, class_id, class_name, text, description}
              (box coordinates are the detector's, truncated to ints; crops used for
              CLIP/OCR/BLIP are clamped to the image)
            - image_size: {width, height} of the (possibly preprocessed) image
            - preprocessed: Whether preprocessing was applied
            - preprocessing_info: mode/preset details, {"error": ...} if preprocessing
              failed, or None when preprocessing was not requested
        """
        img_array = load_image(image)

        # Optional preprocessing for cross-device consistency
        preprocessed = False
        preprocessing_info: Dict = {}
        if preprocess:
            img_array, preprocessed, preprocessing_info = self._apply_preprocessing(
                img_array, preprocess_mode, preprocess_preset
            )

        height, width = img_array.shape[:2]

        # RF-DETR Detection: Detects generic UI elements (SINGLE CLASS ONLY)
        det = self.model.predict(img_array, threshold=confidence_threshold)
        boxes = det.xyxy.tolist()
        scores = det.confidence.tolist()

        detections: List[Dict] = []
        for box, score in zip(boxes, scores):
            x1, y1, x2, y2 = map(int, box)
            cropped = self._crop_region(img_array, x1, y1, x2, y2)

            # CLIP Classification: Classify RF-DETR detection into one of 6 types
            if use_clip and self.enable_clip:
                predicted_class_id = self._classify_with_clip(cropped)
                class_name = self.CLASSES[predicted_class_id] if 0 <= predicted_class_id < len(self.CLASSES) else "unknown"
            else:
                predicted_class_id = None
                class_name = ""

            # OCR text extraction per detection
            text = self._extract_text(cropped) if extract_text and self.enable_ocr else ""

            # BLIP description per detection (kept separate from text)
            description = ""
            if use_blip and self.enable_blip and (
                blip_scope == "all" or class_name in {"image", "button"}
            ):
                description = self._describe_with_blip(cropped)

            detections.append({
                "box": {"x1": float(x1), "y1": float(y1), "x2": float(x2), "y2": float(y2)},
                "confidence": float(score),
                "class_id": predicted_class_id,
                "class_name": class_name,
                "text": text,
                "description": description,
            })

        # Optional global OCR merge: attach stray OCR to nearest detection
        if merge_global_ocr and extract_text and self.enable_ocr:
            self._merge_global_ocr(img_array, detections)

        return {
            "detections": detections,
            "image_size": {"width": int(width), "height": int(height)},
            "preprocessed": preprocessed,
            # Empty dict (not requested) → None; a failure still surfaces its error dict.
            "preprocessing_info": preprocessing_info or None
        }

    def _draw_detections(
        self,
        image: np.ndarray,
        boxes: List[List[float]],
        scores: List[float],
        classes: List[int],
        contents: Optional[List[str]] = None,
        thickness: int = 3,
        font_scale: float = 0.5
    ) -> np.ndarray:
        """Draw detection boxes, confidence labels and optional content captions on a copy of image.

        Args:
            image: BGR image (OpenCV convention); it is not modified.
            boxes: (x1, y1, x2, y2) per detection.
            scores: Confidence per detection, drawn above the box with 2 decimals.
            classes: Class id per detection. Currently only bounds the number of
                boxes drawn (zip) and does not affect the rendering.
            contents: Optional caption per detection drawn below the box,
                truncated to 40 characters.
            thickness: Box line thickness.
            font_scale: Label font scale (captions use 0.8x this).

        Returns:
            The annotated BGR copy.
        """
        img_with_boxes = image.copy()
        # NOTE(review): cv2 Hershey fonts cover ASCII only; accented OCR text
        # (the reader is configured with 'fr') may not render correctly.
        for idx, (box, score, _cls_id) in enumerate(zip(boxes, scores, classes)):
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), self.BOX_COLOR, thickness)

            label = f"{score:.2f}"

            content = ""
            if contents and idx < len(contents) and contents[idx]:
                content = contents[idx]
                # Truncate long content for display
                if len(content) > 40:
                    content = content[:37] + "..."

            (label_width, label_height), baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness=2
            )
            # Keep the label inside the image when the box touches the top edge
            label_y = max(y1 - 10, label_height + 10)
            cv2.rectangle(
                img_with_boxes,
                (x1, label_y - label_height - baseline - 5),
                (x1 + label_width + 5, label_y + baseline - 5),
                self.BOX_COLOR,
                -1
            )
            cv2.putText(
                img_with_boxes,
                label,
                (x1 + 2, label_y - baseline - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                font_scale,
                (255, 255, 255),
                thickness=2
            )

            if content:
                content_font_scale = font_scale * 0.8
                (content_width, content_height), content_baseline = cv2.getTextSize(
                    content, cv2.FONT_HERSHEY_SIMPLEX, content_font_scale, thickness=1
                )
                # Below the box, but never past the bottom edge of the image
                content_y = min(y2 + content_height + 15, img_with_boxes.shape[0] - 5)
                cv2.rectangle(
                    img_with_boxes,
                    (x1, content_y - content_height - content_baseline - 3),
                    (x1 + content_width + 5, content_y + content_baseline),
                    self.CONTENT_BG_COLOR,
                    -1
                )
                cv2.putText(
                    img_with_boxes,
                    content,
                    (x1 + 2, content_y - content_baseline - 3),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    content_font_scale,
                    (255, 255, 255),
                    thickness=1
                )
        return img_with_boxes

    @torch.inference_mode()
    def get_prediction_image(
        self,
        image: Union[str, Path, np.ndarray, Image.Image],
        confidence_threshold: float = 0.35,
        extract_content: bool = True,
        thickness: int = 3,
        font_scale: float = 0.5,
        return_format: str = "pil",
        analysis: Optional[Dict] = None
    ) -> Union[Image.Image, np.ndarray]:
        """
        Get annotated image with detection boxes drawn

        Args:
            image: Input image (path, PIL Image, or numpy array)
            confidence_threshold: Minimum confidence score for detections (0.0-1.0);
                ignored when ``analysis`` is supplied
            extract_content: Whether to extract and display text content or icon descriptions
            thickness: Thickness of bounding box lines
            font_scale: Font scale for labels
            return_format: "pil" (case-insensitive) for PIL Image; anything else returns a numpy array
            analysis: Pre-computed analyze() results (optional, for performance). When
                omitted, analyze() runs with CLIP/BLIP per the instance flags, global
                OCR merge on, and no preprocessing.

        Returns:
            Annotated image as PIL Image or numpy array (RGB). Each caption is the
            detection's text, else "[Icon: <description>]", else nothing.
        """
        img_array = load_image(image)
        if analysis is None:
            analysis = self.analyze(
                image,
                confidence_threshold=confidence_threshold,
                extract_text=extract_content,
                use_clip=self.enable_clip,
                use_blip=self.enable_blip,
                merge_global_ocr=True
            )

        boxes = []
        scores = []
        class_ids = []
        contents = []
        for det in analysis["detections"]:
            b = det["box"]
            boxes.append([b["x1"], b["y1"], b["x2"], b["y2"]])
            scores.append(det["confidence"])
            class_ids.append(det["class_id"] if det.get("class_id") is not None else 0)
            if extract_content:
                text = det.get("text") or ""
                desc = det.get("description") or ""
                contents.append(text if text else (f"[Icon: {desc}]" if desc else ""))

        # OpenCV drawing works in BGR; convert there and back.
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
        annotated_img = self._draw_detections(
            img_bgr, boxes, scores, class_ids,
            contents if extract_content else None,
            thickness, font_scale
        )
        annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)

        if return_format.lower() == "pil":
            return Image.fromarray(annotated_img_rgb)
        return annotated_img_rgb