"""
Detection Service - Core Business Logic

This module contains the main DetectionService class that handles UI element detection.

ARCHITECTURE:
-------------
This service uses a multi-model pipeline:

1. RF-DETR (Detection Transformer)
   - Detects generic "UI elements" as a SINGLE CLASS
   - Provides bounding boxes and confidence scores
   - Does NOT distinguish between button, input, text, etc.

2. CLIP (OpenAI)
   - OPTIONAL multi-class classification
   - Takes RF-DETR detections and classifies them into 6 types:
     * button, input, text, image, list_item, navigation
   - Only runs if enable_clip=True

3. EasyOCR
   - Extracts text content from detected regions
   - Runs global OCR merge to catch text outside detection boxes

4. BLIP (Salesforce)
   - OPTIONAL visual description generation
   - Describes icons and images when text is not present
   - Only runs if enable_blip=True

Usage:
    from detection.service import DetectionService
    
    service = DetectionService()
    results = service.analyze(image_path)
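
    # A lighter configuration that skips the optional CLIP and BLIP stages
    # (illustrative sketch; flag names match the DetectionService/analyze signatures below):
    fast_service = DetectionService(enable_blip=False, enable_clip=False)
    results = fast_service.analyze(image_path, use_clip=False, use_blip=False)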
"""

import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import torch
import cv2
import numpy as np
from PIL import Image
from typing import Union, List, Dict, Tuple, Optional
from pathlib import Path
from rfdetr.detr import RFDETRMedium
import easyocr
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel

from detection.image_utils import load_image
from detection.image_preprocessing import preprocess_screenshot, PRESETS
from detection.rfdetr_preprocessing import preprocess_for_rfdetr, RFDETR_PRESETS


class DetectionService:
    """
    Detection Service for UI Element Detection
    
    Provides a complete pipeline for detecting and analyzing UI elements in screenshots.
    Uses RF-DETR for detection (single class), CLIP for classification (6 classes),
    OCR for text extraction, and BLIP for visual descriptions.
    """
    
    # UI Element classes - Optimized for Mobile Apps
    # NOTE: These are NOT detected by RF-DETR (single class only)
    # CLIP classifies RF-DETR detections into these 6 types
    CLASSES = [
        'button',      # Buttons, FAB, chips, switches
        'input',       # Text fields, search bars
        'text',        # Labels, titles, paragraphs, descriptions
        'image',       # Images, icons, avatars, illustrations
        'list_item',   # List items, cards, tiles
        'navigation'   # Bottom nav, tabs, app bars, menus
    ]
    
    # Default box color (BGR format for OpenCV)
    BOX_COLOR = (0, 255, 0)  # Green
    
    def __init__(self, model_path: str = "model.pth", enable_ocr: bool = True, enable_blip: bool = True, enable_clip: bool = True):
        """
        Initialize the Detection Service
        
        Args:
            model_path: Path to the RF-DETR model weights
            enable_ocr: Whether to enable OCR for text extraction
            enable_blip: Whether to enable BLIP for icon description
            enable_clip: Whether to enable CLIP for UI element classification
        """
        self.model_path = model_path
        self.enable_ocr = enable_ocr
        self.enable_blip = enable_blip
        self.enable_clip = enable_clip
        
        self.model = None
        self.ocr_reader = None
        self.blip_processor = None
        self.blip_model = None
        self.clip_processor = None
        self.clip_model = None
        
        # Load the detection model immediately
        self._load_detection_model()
    
    def _load_detection_model(self):
        """Load RF-DETR model (single-class UI element detector)"""
        if self.model is None:
            print("Loading RF-DETR model...")
            kwargs = {"pretrain_weights": self.model_path}
            custom_resolution = os.getenv("RFDETR_RESOLUTION")
            if custom_resolution:
                try:
                    kwargs["resolution"] = int(custom_resolution)
                    print(f"Using custom RF-DETR resolution: {kwargs['resolution']}")
                except ValueError:
                    print(f"Warning: invalid RFDETR_RESOLUTION '{custom_resolution}'. Falling back to model default.")
            else:
                kwargs["resolution"] = 1600  # Default tuned for CU-1 deployment

            self.model = RFDETRMedium(**kwargs)
            print("RF-DETR model loaded successfully!")
    
    def _load_ocr(self):
        """Load EasyOCR reader for text extraction"""
        if self.enable_ocr and self.ocr_reader is None:
            print("Loading OCR reader...")
            self.ocr_reader = easyocr.Reader(['en', 'fr'], gpu=torch.cuda.is_available())
            print("OCR reader loaded successfully!")
    
    def _load_blip(self):
        """Load BLIP model for image captioning"""
        if self.enable_blip and (self.blip_processor is None or self.blip_model is None):
            print("Loading BLIP model for icon description...")
            self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            # Use safetensors format to avoid torch.load vulnerability (CVE-2025-32434)
            self.blip_model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-base",
                use_safetensors=True
            )
            if torch.cuda.is_available():
                self.blip_model = self.blip_model.to("cuda")
            print("BLIP model loaded successfully!")
    
    def _load_clip(self):
        """Load CLIP model for UI element classification"""
        if self.enable_clip and (self.clip_processor is None or self.clip_model is None):
            print("Loading CLIP model for UI element classification...")
            self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            # Use safetensors format to avoid torch.load vulnerability (CVE-2025-32434)
            self.clip_model = CLIPModel.from_pretrained(
                "openai/clip-vit-base-patch32",
                use_safetensors=True
            )
            if torch.cuda.is_available():
                self.clip_model = self.clip_model.to("cuda")
            print("CLIP model loaded successfully!")
    
    def _classify_with_clip(self, cropped_img: np.ndarray) -> int:
        """
        Classify UI element using CLIP
        
        Args:
            cropped_img: Cropped numpy array of the UI element
            
        Returns:
            Predicted class_id (0-5 corresponding to CLASSES)
        """
        if cropped_img.size == 0:
            return 0  # Default to first class
        
        if not self.enable_clip:
            return 0  # No classification, return default
        
        self._load_clip()
        
        try:
            # Convert numpy array to PIL Image
            pil_img = Image.fromarray(cropped_img)
            
            # Create text prompts for each class - Optimized for mobile UI
            text_prompts = [
                "a mobile app button or interactive element",
                "a text input field or search bar in a mobile app",
                "text label, heading, or paragraph in a mobile app",
                "an image, icon, or avatar in a mobile app",
                "a list item, card, or tile in a mobile app",
                "a navigation bar, tab, or menu in a mobile app"
            ]
            
            # Process with CLIP
            inputs = self.clip_processor(
                text=text_prompts,
                images=pil_img,
                return_tensors="pt",
                padding=True
            )
            
            if torch.cuda.is_available():
                inputs = {k: v.to("cuda") for k, v in inputs.items()}
            
            # Get predictions
            outputs = self.clip_model(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = logits_per_image.softmax(dim=1)
            
            # Get the class with highest probability
            predicted_class_id = probs.argmax().item()
            
            return predicted_class_id
            
        except Exception as clip_error:
            print(f"CLIP classification error: {clip_error}")
            return 0  # Fallback to default class

    def _extract_text(self, cropped_img: np.ndarray) -> str:
        """Extract plain text from a cropped region using OCR (no BLIP)."""
        if not self.enable_ocr or cropped_img.size == 0:
            return ""
        self._load_ocr()
        try:
            ocr_results = self.ocr_reader.readtext(cropped_img, detail=0)
            return " ".join(ocr_results).strip()
        except Exception as ocr_error:
            print(f"OCR error: {ocr_error}")
            return ""

    def _describe_with_blip(self, cropped_img: np.ndarray) -> str:
        """Generate a visual description using BLIP for a cropped region."""
        if not self.enable_blip or cropped_img.size == 0:
            return ""
        self._load_blip()
        try:
            pil_img = Image.fromarray(cropped_img)
            inputs = self.blip_processor(pil_img, return_tensors="pt")
            if torch.cuda.is_available():
                inputs = {k: v.to("cuda") for k, v in inputs.items()}
            out = self.blip_model.generate(**inputs, max_length=50)
            return self.blip_processor.decode(out[0], skip_special_tokens=True)
        except Exception as blip_error:
            print(f"BLIP error: {blip_error}")
            return ""

    @staticmethod
    def _iou(box_a: Tuple[int, int, int, int], box_b: Tuple[int, int, int, int]) -> float:
        """Calculate Intersection over Union between two boxes"""
        xA = max(box_a[0], box_b[0])
        yA = max(box_a[1], box_b[1])
        xB = min(box_a[2], box_b[2])
        yB = min(box_a[3], box_b[3])
        inter_w = max(0, xB - xA)
        inter_h = max(0, yB - yA)
        inter_area = inter_w * inter_h
        if inter_area == 0:
            return 0.0
        box_a_area = max(0, (box_a[2] - box_a[0])) * max(0, (box_a[3] - box_a[1]))
        box_b_area = max(0, (box_b[2] - box_b[0])) * max(0, (box_b[3] - box_b[1]))
        union = box_a_area + box_b_area - inter_area
        if union <= 0:
            return 0.0
        return inter_area / union

    @staticmethod
    def _box_center(box: Tuple[int, int, int, int]) -> Tuple[float, float]:
        """Calculate the center point of a bounding box"""
        x1, y1, x2, y2 = box
        return (x1 + x2) / 2.0, (y1 + y2) / 2.0

    @torch.inference_mode()
    def analyze(
        self,
        image: Union[str, Path, np.ndarray, Image.Image],
        confidence_threshold: float = 0.35,
        extract_text: bool = True,
        use_clip: bool = True,
        use_blip: bool = False,
        merge_global_ocr: bool = True,
        blip_scope: str = "icons",
        preprocess: bool = False,
        preprocess_preset: str = "standard",
        preprocess_mode: str = "rfdetr"
    ) -> Dict:
        """
        Run a single-pass analysis: detection, optional CLIP classification, OCR, optional BLIP,
        and an optional global OCR pass that merges stray text into the nearest detection.
        
        PIPELINE:
        0. Optional preprocessing (normalize colors, contrast, denoise)
        1. RF-DETR detects UI elements (single class - just bounding boxes)
        2. CLIP classifies each detection into 6 types (if use_clip=True)
        3. OCR extracts text from each detection (if extract_text=True)
        4. BLIP generates descriptions for icons (if use_blip=True)
        5. Global OCR merge attaches stray text to nearest detections (if merge_global_ocr=True)

        Args:
            image: Input image (path, PIL Image, or numpy array)
            confidence_threshold: Minimum confidence for RF-DETR detections
            extract_text: Whether to run OCR on detections
            use_clip: Whether to classify detections with CLIP
            use_blip: Whether to generate BLIP descriptions
            merge_global_ocr: Whether to run global OCR and merge results
            blip_scope: "icons" (only image/button) or "all" (all elements)
            preprocess: Enable image preprocessing (recommended for cross-device consistency)
            preprocess_preset: Preprocessing preset - depends on mode:
                               - rfdetr mode: 'gentle', 'standard', 'aggressive_denoise', 'color_only'
                               - generic mode: 'standard', 'aggressive', 'minimal', 'ocr_optimized'
            preprocess_mode: Preprocessing mode - 'rfdetr' (optimized for RF-DETR) or 'generic' (for CLIP/OCR)

        Returns:
            Dict with keys:
                - detections: List of {box, confidence, class_id, class_name, text, description}
                - image_size: {width, height}
                - preprocessed: Whether preprocessing was applied
                - preprocessing_info: Preprocessing details, or None if no preprocessing was applied
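
        Example:
            An illustrative call; "screenshot.png" is a placeholder path and the
            threshold/flags should be adapted to your deployment:

                result = service.analyze("screenshot.png", confidence_threshold=0.4, preprocess=True)
                for det in result["detections"]:
                    box = det["box"]
                    print(det["class_name"], f"{det['confidence']:.2f}", det["text"],
                          (box["x1"], box["y1"], box["x2"], box["y2"]))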
        """
        # Load image
        img_array = load_image(image)
        
        # Optional preprocessing for cross-device consistency
        preprocessed = False
        preprocessing_info = {}
        if preprocess:
            try:
                if preprocess_mode == "rfdetr":
                    # RF-DETR optimized preprocessing (preserves ImageNet normalization)
                    img_array = preprocess_for_rfdetr(img_array, preset=preprocess_preset)
                    preprocessed = True
                    preprocessing_info = {
                        "mode": "rfdetr",
                        "preset": preprocess_preset,
                        "description": "RF-DETR optimized (preserves ImageNet normalization)"
                    }
                elif preprocess_mode == "generic":
                    # Generic preprocessing (for CLIP/OCR optimization)
                    img_array = preprocess_screenshot(img_array, preset=preprocess_preset)
                    preprocessed = True
                    preprocessing_info = {
                        "mode": "generic",
                        "preset": preprocess_preset,
                        "description": "Generic preprocessing (CLIP/OCR optimized)"
                    }
                else:
                    print(f"Warning: Unknown preprocess_mode '{preprocess_mode}'. Using 'rfdetr'.")
                    img_array = preprocess_for_rfdetr(img_array, preset="standard")
                    preprocessed = True
                    preprocessing_info = {
                        "mode": "rfdetr",
                        "preset": "standard",
                        "description": "RF-DETR optimized (fallback)"
                    }
            except Exception as e:
                print(f"Warning: Preprocessing failed: {e}. Continuing with original image.")
                preprocessed = False
                preprocessing_info = {"error": str(e)}
        height, width = img_array.shape[:2]

        # RF-DETR Detection: Detects generic UI elements (SINGLE CLASS ONLY)
        det = self.model.predict(img_array, threshold=confidence_threshold)
        boxes = det.xyxy.tolist()
        scores = det.confidence.tolist()

        detections: List[Dict] = []
        for box, score in zip(boxes, scores):
            x1, y1, x2, y2 = map(int, box)
            cropped = img_array[y1:y2, x1:x2]

            # CLIP Classification: Classify RF-DETR detection into one of 6 types
            if use_clip and self.enable_clip:
                predicted_class_id = self._classify_with_clip(cropped)
                class_name = self.CLASSES[predicted_class_id] if 0 <= predicted_class_id < len(self.CLASSES) else "unknown"
            else:
                predicted_class_id = None
                class_name = ""

            # OCR text extraction per detection
            text = self._extract_text(cropped) if extract_text and self.enable_ocr else ""

            # BLIP description per detection (keep separate from text)
            description = ""
            if use_blip and self.enable_blip and (
                blip_scope == "all" or class_name in {"image", "button"}
            ):
                description = self._describe_with_blip(cropped)

            detections.append({
                "box": {"x1": float(x1), "y1": float(y1), "x2": float(x2), "y2": float(y2)},
                "confidence": float(score),
                "class_id": predicted_class_id,
                "class_name": class_name,
                "text": text,
                "description": description,
            })

        # Optional global OCR merge: attach stray OCR to nearest detection
        if merge_global_ocr and extract_text and self.enable_ocr:
            try:
                self._load_ocr()
                # detail=1 returns [ [ (x,y)...4 points ], text, conf ]
                global_ocr = self.ocr_reader.readtext(img_array, detail=1)
                # Precompute detection boxes as tuples
                det_boxes: List[Tuple[int, int, int, int]] = []
                for d in detections:
                    b = d["box"]
                    det_boxes.append((int(b["x1"]), int(b["y1"]), int(b["x2"]), int(b["y2"])))

                for entry in global_ocr:
                    if not isinstance(entry, (list, tuple)) or len(entry) < 2:
                        continue
                    quad = entry[0]
                    text = entry[1] if isinstance(entry[1], str) else ""
                    if not text:
                        continue
                    # Convert quadrilateral to bounding box
                    xs = [p[0] for p in quad]
                    ys = [p[1] for p in quad]
                    obox = (int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys)))

                    # Overlap with existing detections (IoU >= 0.1) → attach to best-overlap detection
                    overlaps = [self._iou(obox, db) for db in det_boxes]
                    if overlaps:
                        max_iou = max(overlaps)
                        if max_iou >= 0.1:
                            best_overlap_idx = int(np.argmax(np.array(overlaps)))
                            existing = detections[best_overlap_idx]["text"].strip()
                            if text not in existing:
                                detections[best_overlap_idx]["text"] = (
                                    existing + (" " if existing else "") + text
                                ).strip()
                            # Attached to overlapping detection; proceed to next OCR entry
                            continue

                    # No sufficient overlap → find nearest detection by center distance
                    ox, oy = self._box_center(obox)
                    best_idx = -1
                    best_dist = float("inf")
                    for idx, dbox in enumerate(det_boxes):
                        cx, cy = self._box_center(dbox)
                        dx = cx - ox
                        dy = cy - oy
                        dist2 = dx * dx + dy * dy
                        if dist2 < best_dist:
                            best_dist = dist2
                            best_idx = idx
                    if best_idx >= 0:
                        # Conservative distance threshold: within 0.3 of detection diagonal
                        bx1, by1, bx2, by2 = det_boxes[best_idx]
                        bw = max(1, bx2 - bx1)
                        bh = max(1, by2 - by1)
                        diag2 = bw * bw + bh * bh
                        if best_dist <= 0.09 * diag2:  # (0.3 * diag)^2
                            existing = detections[best_idx]["text"].strip()
                            if text not in existing:
                                detections[best_idx]["text"] = (
                                    existing + (" " if existing else "") + text
                                ).strip()
                            continue

                    # Not overlapping or near any detection → create a new OCR-only detection
                    new_det = {
                        "box": {
                            "x1": float(obox[0]),
                            "y1": float(obox[1]),
                            "x2": float(obox[2]),
                            "y2": float(obox[3]),
                        },
                        "confidence": float(entry[2]) if len(entry) > 2 and entry[2] is not None else 1.0,
                        "class_id": None,
                        "class_name": "",
                        "text": text.strip(),
                        "description": "",
                    }
                    detections.append(new_det)
                    det_boxes.append(obox)
            except Exception as e:
                print(f"Global OCR merge error: {e}")

        return {
            "detections": detections,
            "image_size": {"width": int(width), "height": int(height)},
            "preprocessed": preprocessed,
            "preprocessing_info": preprocessing_info if preprocessed else None
        }
    
    
    def _draw_detections(
        self,
        image: np.ndarray,
        boxes: List[List[float]],
        scores: List[float],
        classes: List[int],
        contents: Optional[List[str]] = None,
        thickness: int = 3,
        font_scale: float = 0.5
    ) -> np.ndarray:
        """Draw detection boxes and labels on image"""
        img_with_boxes = image.copy()

        for idx, (box, score, cls_id) in enumerate(zip(boxes, scores, classes)):
            x1, y1, x2, y2 = map(int, box)

            # Draw rectangle
            cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), self.BOX_COLOR, thickness)

            # Prepare label with confidence score
            label = f"{score:.2f}"
            
            # Add content if available
            content = ""
            if contents and idx < len(contents) and contents[idx]:
                content = contents[idx]
                # Truncate long content for display
                if len(content) > 40:
                    content = content[:37] + "..."

            # Calculate label size and position
            (label_width, label_height), baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness=2
            )

            # Draw label background
            label_y = max(y1 - 10, label_height + 10)
            cv2.rectangle(
                img_with_boxes,
                (x1, label_y - label_height - baseline - 5),
                (x1 + label_width + 5, label_y + baseline - 5),
                self.BOX_COLOR,
                -1
            )

            # Draw label text (confidence score)
            cv2.putText(
                img_with_boxes,
                label,
                (x1 + 2, label_y - baseline - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                font_scale,
                (255, 255, 255),
                thickness=2
            )
            
            # Draw content text below the box if available
            if content:
                content_font_scale = font_scale * 0.8
                (content_width, content_height), content_baseline = cv2.getTextSize(
                    content, cv2.FONT_HERSHEY_SIMPLEX, content_font_scale, thickness=1
                )
                
                # Position content below the bottom of the box
                content_y = min(y2 + content_height + 15, img_with_boxes.shape[0] - 5)
                
                # Draw content background
                cv2.rectangle(
                    img_with_boxes,
                    (x1, content_y - content_height - content_baseline - 3),
                    (x1 + content_width + 5, content_y + content_baseline),
                    (0, 180, 0),  # Slightly darker green
                    -1
                )
                
                # Draw content text
                cv2.putText(
                    img_with_boxes,
                    content,
                    (x1 + 2, content_y - content_baseline - 3),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    content_font_scale,
                    (255, 255, 255),
                    thickness=1
                )

        return img_with_boxes

    @torch.inference_mode()
    def get_prediction_image(
        self,
        image: Union[str, Path, np.ndarray, Image.Image],
        confidence_threshold: float = 0.35,
        extract_content: bool = True,
        thickness: int = 3,
        font_scale: float = 0.5,
        return_format: str = "pil",
        analysis: Optional[Dict] = None
    ) -> Union[Image.Image, np.ndarray]:
        """
        Get annotated image with detection boxes drawn
        
        Args:
            image: Input image (path, PIL Image, or numpy array)
            confidence_threshold: Minimum confidence score for detections (0.0-1.0)
            extract_content: Whether to extract and display text content or icon descriptions
            thickness: Thickness of bounding box lines
            font_scale: Font scale for labels
            return_format: Return format - "pil" for PIL Image or "numpy" for numpy array
            analysis: Pre-computed analysis results (optional, for performance)
            
        Returns:
            Annotated image as PIL Image or numpy array (RGB)
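
        Example:
            Illustrative usage; the input and output paths are placeholders:

                annotated = service.get_prediction_image("screenshot.png", return_format="pil")
                annotated.save("annotated_screenshot.png")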
        """
        # Load image
        img_array = load_image(image)

        if analysis is None:
            analysis = self.analyze(
                image,
                confidence_threshold=confidence_threshold,
                extract_text=extract_content,
                use_clip=self.enable_clip,
                use_blip=self.enable_blip,
                merge_global_ocr=True
            )
        boxes = []
        scores = []
        class_ids = []
        contents = []
        for det in analysis["detections"]:
            b = det["box"]
            boxes.append([b["x1"], b["y1"], b["x2"], b["y2"]])
            scores.append(det["confidence"])
            class_ids.append(det["class_id"] if det.get("class_id") is not None else 0)
            if extract_content:
                text = det.get("text") or ""
                desc = det.get("description") or ""
                contents.append(text if text else (f"[Icon: {desc}]" if desc else ""))
        
        # Convert to BGR for OpenCV
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
        
        # Draw detections
        annotated_img = self._draw_detections(
            img_bgr, boxes, scores, class_ids,
            contents if extract_content else None,
            thickness, font_scale
        )
        
        # Convert back to RGB
        annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
        
        # Return in requested format
        if return_format.lower() == "pil":
            return Image.fromarray(annotated_img_rgb)
        else:
            return annotated_img_rgb