Spaces:
Sleeping
Sleeping
File size: 26,924 Bytes
77da9e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 |
"""
Detection Service - Core Business Logic
This module contains the main DetectionService class that handles UI element detection.
ARCHITECTURE:
-------------
This service uses a multi-model pipeline:
1. RF-DETR (Detection Transformer)
- Detects generic "UI elements" as a SINGLE CLASS
- Provides bounding boxes and confidence scores
- Does NOT distinguish between button, input, text, etc.
2. CLIP (OpenAI)
- OPTIONAL multi-class classification
- Takes RF-DETR detections and classifies them into 6 types:
* button, input, text, image, list_item, navigation
- Only runs if enable_clip=True
3. EasyOCR
- Extracts text content from detected regions
- Runs global OCR merge to catch text outside detection boxes
4. BLIP (Salesforce)
- OPTIONAL visual description generation
- Describes icons and images when text is not present
- Only runs if enable_blip=True
Usage:
from detection.service import DetectionService
service = DetectionService()
results = service.analyze(image_path)
"""
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch
import cv2
import numpy as np
from PIL import Image
from typing import Union, List, Dict, Tuple, Optional
from pathlib import Path
from rfdetr.detr import RFDETRMedium
import easyocr
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
from detection.image_utils import load_image
from detection.image_preprocessing import preprocess_screenshot, PRESETS
from detection.rfdetr_preprocessing import preprocess_for_rfdetr, RFDETR_PRESETS
class DetectionService:
    """
    Detection Service for UI Element Detection.

    Provides a complete pipeline for detecting and analyzing UI elements in
    screenshots. Uses RF-DETR for detection (single class), CLIP for
    classification (6 classes), EasyOCR for text extraction, and BLIP for
    visual descriptions.
    """

    # UI Element classes - Optimized for Mobile Apps
    # NOTE: These are NOT detected by RF-DETR (single class only);
    # CLIP classifies RF-DETR detections into these 6 types.
    CLASSES = [
        'button',      # Buttons, FAB, chips, switches
        'input',       # Text fields, search bars
        'text',        # Labels, titles, paragraphs, descriptions
        'image',       # Images, icons, avatars, illustrations
        'list_item',   # List items, cards, tiles
        'navigation'   # Bottom nav, tabs, app bars, menus
    ]

    # Default box color (BGR format for OpenCV)
    BOX_COLOR = (0, 255, 0)  # Green

    def __init__(self, model_path: str = "model.pth", enable_ocr: bool = True,
                 enable_blip: bool = True, enable_clip: bool = True):
        """
        Initialize the Detection Service.

        Args:
            model_path: Path to the RF-DETR model weights
            enable_ocr: Whether to enable OCR for text extraction
            enable_blip: Whether to enable BLIP for icon description
            enable_clip: Whether to enable CLIP for UI element classification
        """
        self.model_path = model_path
        self.enable_ocr = enable_ocr
        self.enable_blip = enable_blip
        self.enable_clip = enable_clip
        # Lazily-loaded model handles: OCR/BLIP/CLIP stay None until first use.
        self.model = None
        self.ocr_reader = None
        self.blip_processor = None
        self.blip_model = None
        self.clip_processor = None
        self.clip_model = None
        # Load the detection model immediately (it is always needed).
        self._load_detection_model()

    def _load_detection_model(self):
        """Load RF-DETR model (single-class UI element detector). Idempotent."""
        if self.model is None:
            print("Loading RF-DETR model...")
            kwargs = {"pretrain_weights": self.model_path}
            # Allow overriding the input resolution via environment variable;
            # an invalid value falls back to the model default.
            custom_resolution = os.getenv("RFDETR_RESOLUTION")
            if custom_resolution:
                try:
                    kwargs["resolution"] = int(custom_resolution)
                    print(f"Using custom RF-DETR resolution: {kwargs['resolution']}")
                except ValueError:
                    print(f"Warning: invalid RFDETR_RESOLUTION '{custom_resolution}'. Falling back to model default.")
            else:
                kwargs["resolution"] = 1600  # Default tuned for CU-1 deployment
            self.model = RFDETRMedium(**kwargs)
            print("RF-DETR model loaded successfully!")

    def _load_ocr(self):
        """Load EasyOCR reader for text extraction (English + French). Idempotent."""
        if self.enable_ocr and self.ocr_reader is None:
            print("Loading OCR reader...")
            self.ocr_reader = easyocr.Reader(['en', 'fr'], gpu=torch.cuda.is_available())
            print("OCR reader loaded successfully!")

    def _load_blip(self):
        """Load BLIP model for image captioning. Idempotent."""
        if self.enable_blip and (self.blip_processor is None or self.blip_model is None):
            print("Loading BLIP model for icon description...")
            self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            # Use safetensors format to avoid torch.load vulnerability (CVE-2025-32434)
            self.blip_model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-base",
                use_safetensors=True
            )
            if torch.cuda.is_available():
                self.blip_model = self.blip_model.to("cuda")
            print("BLIP model loaded successfully!")

    def _load_clip(self):
        """Load CLIP model for UI element classification. Idempotent."""
        if self.enable_clip and (self.clip_processor is None or self.clip_model is None):
            print("Loading CLIP model for UI element classification...")
            self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            # Use safetensors format to avoid torch.load vulnerability (CVE-2025-32434)
            self.clip_model = CLIPModel.from_pretrained(
                "openai/clip-vit-base-patch32",
                use_safetensors=True
            )
            if torch.cuda.is_available():
                self.clip_model = self.clip_model.to("cuda")
            print("CLIP model loaded successfully!")

    def _classify_with_clip(self, cropped_img: np.ndarray) -> int:
        """
        Classify a UI element crop using CLIP zero-shot text prompts.

        Args:
            cropped_img: Cropped numpy array (RGB) of the UI element

        Returns:
            Predicted class_id (0-5, indexing into CLASSES); 0 on empty crop,
            disabled CLIP, or any classification error.
        """
        if cropped_img.size == 0:
            return 0  # Default to first class
        if not self.enable_clip:
            return 0  # No classification, return default
        self._load_clip()
        try:
            # CLIP expects a PIL image
            pil_img = Image.fromarray(cropped_img)
            # One text prompt per class, index-aligned with CLASSES
            text_prompts = [
                "a mobile app button or interactive element",
                "a text input field or search bar in a mobile app",
                "text label, heading, or paragraph in a mobile app",
                "an image, icon, or avatar in a mobile app",
                "a list item, card, or tile in a mobile app",
                "a navigation bar, tab, or menu in a mobile app"
            ]
            inputs = self.clip_processor(
                text=text_prompts,
                images=pil_img,
                return_tensors="pt",
                padding=True
            )
            if torch.cuda.is_available():
                inputs = {k: v.to("cuda") for k, v in inputs.items()}
            outputs = self.clip_model(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = logits_per_image.softmax(dim=1)
            # Highest-probability prompt index == class_id
            predicted_class_id = probs.argmax().item()
            return predicted_class_id
        except Exception as clip_error:
            print(f"CLIP classification error: {clip_error}")
            return 0  # Fallback to default class

    def _extract_text(self, cropped_img: np.ndarray) -> str:
        """Extract plain text from a cropped region using OCR (no BLIP). Best-effort: returns '' on error."""
        if not self.enable_ocr or cropped_img.size == 0:
            return ""
        self._load_ocr()
        try:
            # detail=0 returns just the recognized strings
            ocr_results = self.ocr_reader.readtext(cropped_img, detail=0)
            return " ".join(ocr_results).strip()
        except Exception as ocr_error:
            print(f"OCR error: {ocr_error}")
            return ""

    def _describe_with_blip(self, cropped_img: np.ndarray) -> str:
        """Generate a visual description for a cropped region using BLIP. Best-effort: returns '' on error."""
        if not self.enable_blip or cropped_img.size == 0:
            return ""
        self._load_blip()
        try:
            pil_img = Image.fromarray(cropped_img)
            inputs = self.blip_processor(pil_img, return_tensors="pt")
            if torch.cuda.is_available():
                inputs = {k: v.to("cuda") for k, v in inputs.items()}
            out = self.blip_model.generate(**inputs, max_length=50)
            return self.blip_processor.decode(out[0], skip_special_tokens=True)
        except Exception as blip_error:
            print(f"BLIP error: {blip_error}")
            return ""

    @staticmethod
    def _iou(box_a: Tuple[int, int, int, int], box_b: Tuple[int, int, int, int]) -> float:
        """Calculate Intersection over Union between two (x1, y1, x2, y2) boxes. Returns 0.0 for degenerate boxes."""
        xA = max(box_a[0], box_b[0])
        yA = max(box_a[1], box_b[1])
        xB = min(box_a[2], box_b[2])
        yB = min(box_a[3], box_b[3])
        inter_w = max(0, xB - xA)
        inter_h = max(0, yB - yA)
        inter_area = inter_w * inter_h
        if inter_area == 0:
            return 0.0
        # max(0, ...) guards against inverted (x2 < x1) inputs
        box_a_area = max(0, (box_a[2] - box_a[0])) * max(0, (box_a[3] - box_a[1]))
        box_b_area = max(0, (box_b[2] - box_b[0])) * max(0, (box_b[3] - box_b[1]))
        union = box_a_area + box_b_area - inter_area
        if union <= 0:
            return 0.0
        return inter_area / union

    @staticmethod
    def _box_center(box: Tuple[int, int, int, int]) -> Tuple[float, float]:
        """Calculate the center point (cx, cy) of an (x1, y1, x2, y2) bounding box."""
        x1, y1, x2, y2 = box
        return (x1 + x2) / 2.0, (y1 + y2) / 2.0

    @torch.inference_mode()
    def analyze(
        self,
        image: Union[str, Path, np.ndarray, Image.Image],
        confidence_threshold: float = 0.35,
        extract_text: bool = True,
        use_clip: bool = True,
        use_blip: bool = False,
        merge_global_ocr: bool = True,
        blip_scope: str = "icons",
        preprocess: bool = False,
        preprocess_preset: str = "standard",
        preprocess_mode: str = "rfdetr"
    ) -> Dict:
        """
        Run a single-pass analysis: detection, optional CLIP classification, OCR,
        optional BLIP, and optional global OCR merge into nearest detection.

        PIPELINE:
        0. Optional preprocessing (normalize colors, contrast, denoise)
        1. RF-DETR detects UI elements (single class - just bounding boxes)
        2. CLIP classifies each detection into 6 types (if use_clip=True)
        3. OCR extracts text from each detection (if extract_text=True)
        4. BLIP generates descriptions for icons (if use_blip=True)
        5. Global OCR merge attaches stray text to nearest detections (if merge_global_ocr=True)

        Args:
            image: Input image (path, PIL Image, or numpy array)
            confidence_threshold: Minimum confidence for RF-DETR detections
            extract_text: Whether to run OCR on detections
            use_clip: Whether to classify detections with CLIP
            use_blip: Whether to generate BLIP descriptions
            merge_global_ocr: Whether to run global OCR and merge results
            blip_scope: "icons" (only image/button) or "all" (all elements)
            preprocess: Enable image preprocessing (recommended for cross-device consistency)
            preprocess_preset: Preprocessing preset - depends on mode:
                - rfdetr mode: 'gentle', 'standard', 'aggressive_denoise', 'color_only'
                - generic mode: 'standard', 'aggressive', 'minimal', 'ocr_optimized'
            preprocess_mode: Preprocessing mode - 'rfdetr' (optimized for RF-DETR) or 'generic' (for CLIP/OCR)

        Returns:
            Dict with keys:
                - detections: List of {box, confidence, class_id, class_name, text, description}
                - image_size: {width, height}
                - preprocessed: Whether preprocessing was applied
                - preprocessing_info: Mode/preset details, or None if not preprocessed
        """
        img_array = load_image(image)

        # Optional preprocessing for cross-device consistency; failures are
        # non-fatal and fall back to the original image.
        preprocessed = False
        preprocessing_info = {}
        if preprocess:
            try:
                if preprocess_mode == "rfdetr":
                    # RF-DETR optimized preprocessing (preserves ImageNet normalization)
                    img_array = preprocess_for_rfdetr(img_array, preset=preprocess_preset)
                    preprocessed = True
                    preprocessing_info = {
                        "mode": "rfdetr",
                        "preset": preprocess_preset,
                        "description": "RF-DETR optimized (preserves ImageNet normalization)"
                    }
                elif preprocess_mode == "generic":
                    # Generic preprocessing (for CLIP/OCR optimization)
                    img_array = preprocess_screenshot(img_array, preset=preprocess_preset)
                    preprocessed = True
                    preprocessing_info = {
                        "mode": "generic",
                        "preset": preprocess_preset,
                        "description": "Generic preprocessing (CLIP/OCR optimized)"
                    }
                else:
                    print(f"Warning: Unknown preprocess_mode '{preprocess_mode}'. Using 'rfdetr'.")
                    img_array = preprocess_for_rfdetr(img_array, preset="standard")
                    preprocessed = True
                    preprocessing_info = {
                        "mode": "rfdetr",
                        "preset": "standard",
                        "description": "RF-DETR optimized (fallback)"
                    }
            except Exception as e:
                print(f"Warning: Preprocessing failed: {e}. Continuing with original image.")
                preprocessed = False
                preprocessing_info = {"error": str(e)}

        height, width = img_array.shape[:2]

        # RF-DETR Detection: generic UI elements (SINGLE CLASS ONLY)
        det = self.model.predict(img_array, threshold=confidence_threshold)
        boxes = det.xyxy.tolist()
        scores = det.confidence.tolist()

        detections: List[Dict] = []
        for box, score in zip(boxes, scores):
            x1, y1, x2, y2 = map(int, box)
            # FIX: clamp coordinates to image bounds. RF-DETR can emit boxes
            # slightly outside the image; negative x1/y1 would wrap around in
            # numpy slicing and yield wrong or empty crops for CLIP/OCR/BLIP.
            x1 = max(0, min(x1, width))
            x2 = max(0, min(x2, width))
            y1 = max(0, min(y1, height))
            y2 = max(0, min(y2, height))
            cropped = img_array[y1:y2, x1:x2]

            # CLIP Classification: classify RF-DETR detection into one of 6 types
            if use_clip and self.enable_clip:
                predicted_class_id = self._classify_with_clip(cropped)
                class_name = self.CLASSES[predicted_class_id] if 0 <= predicted_class_id < len(self.CLASSES) else "unknown"
            else:
                predicted_class_id = None
                class_name = ""

            # OCR text extraction per detection
            text = self._extract_text(cropped) if extract_text and self.enable_ocr else ""

            # BLIP description per detection (kept separate from OCR text);
            # "icons" scope limits BLIP to image/button elements.
            description = ""
            if use_blip and self.enable_blip and (
                blip_scope == "all" or class_name in {"image", "button"}
            ):
                description = self._describe_with_blip(cropped)

            detections.append({
                "box": {"x1": float(x1), "y1": float(y1), "x2": float(x2), "y2": float(y2)},
                "confidence": float(score),
                "class_id": predicted_class_id,
                "class_name": class_name,
                "text": text,
                "description": description,
            })

        # Optional global OCR merge: attach stray OCR text to detections, or
        # create OCR-only detections when no detection is close enough.
        if merge_global_ocr and extract_text and self.enable_ocr:
            try:
                self._load_ocr()
                # detail=1 returns [ [ (x,y)...4 points ], text, conf ]
                global_ocr = self.ocr_reader.readtext(img_array, detail=1)
                # Precompute detection boxes as tuples
                det_boxes: List[Tuple[int, int, int, int]] = []
                for d in detections:
                    b = d["box"]
                    det_boxes.append((int(b["x1"]), int(b["y1"]), int(b["x2"]), int(b["y2"])))
                for entry in global_ocr:
                    if not isinstance(entry, (list, tuple)) or len(entry) < 2:
                        continue
                    quad = entry[0]
                    ocr_text = entry[1] if isinstance(entry[1], str) else ""
                    if not ocr_text:
                        continue
                    # Convert quadrilateral to axis-aligned bounding box
                    xs = [p[0] for p in quad]
                    ys = [p[1] for p in quad]
                    obox = (int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys)))
                    # Overlap with existing detections (IoU >= 0.1) → attach to best-overlap detection
                    overlaps = [self._iou(obox, db) for db in det_boxes]
                    if overlaps:
                        max_iou = max(overlaps)
                        if max_iou >= 0.1:
                            best_overlap_idx = int(np.argmax(np.array(overlaps)))
                            existing = detections[best_overlap_idx]["text"].strip()
                            # Avoid duplicating text the detection already carries
                            if ocr_text not in existing:
                                detections[best_overlap_idx]["text"] = (
                                    existing + (" " if existing else "") + ocr_text
                                ).strip()
                            # Attached to overlapping detection; proceed to next OCR entry
                            continue
                    # No sufficient overlap → find nearest detection by center distance
                    ox, oy = self._box_center(obox)
                    best_idx = -1
                    best_dist = float("inf")
                    for idx, dbox in enumerate(det_boxes):
                        cx, cy = self._box_center(dbox)
                        dx = cx - ox
                        dy = cy - oy
                        dist2 = dx * dx + dy * dy
                        if dist2 < best_dist:
                            best_dist = dist2
                            best_idx = idx
                    if best_idx >= 0:
                        # Conservative distance threshold: within 0.3 of detection diagonal
                        bx1, by1, bx2, by2 = det_boxes[best_idx]
                        bw = max(1, bx2 - bx1)
                        bh = max(1, by2 - by1)
                        diag2 = bw * bw + bh * bh
                        if best_dist <= 0.09 * diag2:  # (0.3 * diag)^2
                            existing = detections[best_idx]["text"].strip()
                            if ocr_text not in existing:
                                detections[best_idx]["text"] = (
                                    existing + (" " if existing else "") + ocr_text
                                ).strip()
                            continue
                    # Not overlapping or near any detection → create a new OCR-only detection
                    new_det = {
                        "box": {
                            "x1": float(obox[0]),
                            "y1": float(obox[1]),
                            "x2": float(obox[2]),
                            "y2": float(obox[3]),
                        },
                        "confidence": float(entry[2]) if len(entry) > 2 and entry[2] is not None else 1.0,
                        "class_id": None,
                        "class_name": "",
                        "text": ocr_text.strip(),
                        "description": "",
                    }
                    detections.append(new_det)
                    # Keep det_boxes in sync so later OCR entries can attach here
                    det_boxes.append(obox)
            except Exception as e:
                print(f"Global OCR merge error: {e}")

        return {
            "detections": detections,
            "image_size": {"width": int(width), "height": int(height)},
            "preprocessed": preprocessed,
            "preprocessing_info": preprocessing_info if preprocessed else None
        }

    def _draw_detections(
        self,
        image: np.ndarray,
        boxes: List[List[float]],
        scores: List[float],
        classes: List[int],
        contents: Optional[List[str]] = None,
        thickness: int = 3,
        font_scale: float = 0.5
    ) -> np.ndarray:
        """
        Draw detection boxes and labels on a BGR image.

        Args:
            image: BGR image to annotate (not modified; a copy is returned)
            boxes: List of [x1, y1, x2, y2] boxes
            scores: Confidence score per box (shown as the label)
            classes: Class id per box (currently unused in rendering; kept for
                interface stability)
            contents: Optional text/description per box, drawn below the box
            thickness: Bounding box line thickness
            font_scale: Font scale for the confidence label

        Returns:
            Annotated copy of the input image (BGR).
        """
        img_with_boxes = image.copy()
        for idx, (box, score, cls_id) in enumerate(zip(boxes, scores, classes)):
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), self.BOX_COLOR, thickness)
            # Label shows the confidence score
            label = f"{score:.2f}"
            # Optional content caption, truncated for display
            content = ""
            if contents and idx < len(contents) and contents[idx]:
                content = contents[idx]
                if len(content) > 40:
                    content = content[:37] + "..."
            (label_width, label_height), baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness=2
            )
            # Draw label background above the box (clamped to stay on-image)
            label_y = max(y1 - 10, label_height + 10)
            cv2.rectangle(
                img_with_boxes,
                (x1, label_y - label_height - baseline - 5),
                (x1 + label_width + 5, label_y + baseline - 5),
                self.BOX_COLOR,
                -1
            )
            # Confidence score text in white
            cv2.putText(
                img_with_boxes,
                label,
                (x1 + 2, label_y - baseline - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                font_scale,
                (255, 255, 255),
                thickness=2
            )
            # Content text below the box, if available
            if content:
                content_font_scale = font_scale * 0.8
                (content_width, content_height), content_baseline = cv2.getTextSize(
                    content, cv2.FONT_HERSHEY_SIMPLEX, content_font_scale, thickness=1
                )
                # Position content below the bottom of the box, clamped on-image
                content_y = min(y2 + content_height + 15, img_with_boxes.shape[0] - 5)
                cv2.rectangle(
                    img_with_boxes,
                    (x1, content_y - content_height - content_baseline - 3),
                    (x1 + content_width + 5, content_y + content_baseline),
                    (0, 180, 0),  # Slightly darker green
                    -1
                )
                cv2.putText(
                    img_with_boxes,
                    content,
                    (x1 + 2, content_y - content_baseline - 3),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    content_font_scale,
                    (255, 255, 255),
                    thickness=1
                )
        return img_with_boxes

    @torch.inference_mode()
    def get_prediction_image(
        self,
        image: Union[str, Path, np.ndarray, Image.Image],
        confidence_threshold: float = 0.35,
        extract_content: bool = True,
        thickness: int = 3,
        font_scale: float = 0.5,
        return_format: str = "pil",
        analysis: Optional[Dict] = None
    ) -> Union[Image.Image, np.ndarray]:
        """
        Get annotated image with detection boxes drawn.

        Args:
            image: Input image (path, PIL Image, or numpy array)
            confidence_threshold: Minimum confidence score for detections (0.0-1.0)
            extract_content: Whether to extract and display text content or icon descriptions
            thickness: Thickness of bounding box lines
            font_scale: Font scale for labels
            return_format: Return format - "pil" for PIL Image or "numpy" for numpy array
            analysis: Pre-computed analysis results (optional, for performance)

        Returns:
            Annotated image as PIL Image or numpy array (RGB).
        """
        img_array = load_image(image)
        # Run the full pipeline unless the caller supplies cached results
        if analysis is None:
            analysis = self.analyze(
                image,
                confidence_threshold=confidence_threshold,
                extract_text=extract_content,
                use_clip=self.enable_clip,
                use_blip=self.enable_blip,
                merge_global_ocr=True
            )
        # Flatten analysis dicts into the parallel lists _draw_detections expects
        boxes = []
        scores = []
        class_ids = []
        contents = []
        for det in analysis["detections"]:
            b = det["box"]
            boxes.append([b["x1"], b["y1"], b["x2"], b["y2"]])
            scores.append(det["confidence"])
            class_ids.append(det["class_id"] if det.get("class_id") is not None else 0)
            if extract_content:
                text = det.get("text") or ""
                desc = det.get("description") or ""
                # Prefer OCR text; fall back to a tagged BLIP description
                contents.append(text if text else (f"[Icon: {desc}]" if desc else ""))
        # OpenCV drawing expects BGR
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
        annotated_img = self._draw_detections(
            img_bgr, boxes, scores, class_ids,
            contents if extract_content else None,
            thickness, font_scale
        )
        # Convert back to RGB for the caller
        annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
        if return_format.lower() == "pil":
            return Image.fromarray(annotated_img_rgb)
        else:
            return annotated_img_rgb
|