| """ | |
| Detection Service - Core Business Logic | |
| This module contains the main DetectionService class that handles UI element detection. | |
| ARCHITECTURE: | |
| ------------- | |
| This service uses a multi-model pipeline: | |
| 1. RF-DETR (Detection Transformer) | |
| - Detects generic "UI elements" as a SINGLE CLASS | |
| - Provides bounding boxes and confidence scores | |
| - Does NOT distinguish between button, input, text, etc. | |
| 2. CLIP (OpenAI) | |
| - OPTIONAL multi-class classification | |
| - Takes RF-DETR detections and classifies them into 6 types: | |
| * button, input, text, image, list_item, navigation | |
| - Only runs if enable_clip=True | |
| 3. EasyOCR | |
| - Extracts text content from detected regions | |
| - Runs global OCR merge to catch text outside detection boxes | |
| 4. BLIP (Salesforce) | |
| - OPTIONAL visual description generation | |
| - Describes icons and images when text is not present | |
| - Only runs if enable_blip=True | |
| Usage: | |
| from detection.service import DetectionService | |
| service = DetectionService() | |
| results = service.analyze(image_path) | |
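
    # A fuller call (illustrative only; the file name is a placeholder):
    results = service.analyze(
        "screenshot.png",
        confidence_threshold=0.35,
        use_clip=True,
        use_blip=False,
        merge_global_ocr=True,
        preprocess=True,
        preprocess_mode="rfdetr",
        preprocess_preset="standard",
    )
    for det in results["detections"]:
        print(det["class_name"], det["confidence"], det["text"])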
| """ | |
| import os | |
| os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from typing import Union, List, Dict, Tuple, Optional | |
| from pathlib import Path | |
| from rfdetr.detr import RFDETRMedium | |
| import easyocr | |
| from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel | |
| from detection.image_utils import load_image | |
| from detection.image_preprocessing import preprocess_screenshot, PRESETS | |
| from detection.rfdetr_preprocessing import preprocess_for_rfdetr, RFDETR_PRESETS | |


class DetectionService:
    """
    Detection Service for UI Element Detection

    Provides a complete pipeline for detecting and analyzing UI elements in screenshots.
    Uses RF-DETR for detection (single class), CLIP for classification (6 classes),
    OCR for text extraction, and BLIP for visual descriptions.
    """

    # UI Element classes - Optimized for Mobile Apps
    # NOTE: These are NOT detected by RF-DETR (single class only);
    # CLIP classifies RF-DETR detections into these 6 types.
    CLASSES = [
        'button',      # Buttons, FAB, chips, switches
        'input',       # Text fields, search bars
        'text',        # Labels, titles, paragraphs, descriptions
        'image',       # Images, icons, avatars, illustrations
        'list_item',   # List items, cards, tiles
        'navigation'   # Bottom nav, tabs, app bars, menus
    ]
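    # For reference: CLIP returns an index into CLASSES, so class_id 0 maps to
    # 'button' and class_id 5 to 'navigation' in analyze() results.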

    # Default box color (BGR format for OpenCV)
    BOX_COLOR = (0, 255, 0)  # Green

    def __init__(self, model_path: str = "model.pth", enable_ocr: bool = True,
                 enable_blip: bool = True, enable_clip: bool = True):
        """
        Initialize the Detection Service

        Args:
            model_path: Path to the RF-DETR model weights
            enable_ocr: Whether to enable OCR for text extraction
            enable_blip: Whether to enable BLIP for icon description
            enable_clip: Whether to enable CLIP for UI element classification
        """
        self.model_path = model_path
        self.enable_ocr = enable_ocr
        self.enable_blip = enable_blip
        self.enable_clip = enable_clip

        self.model = None
        self.ocr_reader = None
        self.blip_processor = None
        self.blip_model = None
        self.clip_processor = None
        self.clip_model = None

        # Load the detection model immediately
        self._load_detection_model()

    def _load_detection_model(self):
        """Load RF-DETR model (single-class UI element detector)"""
        if self.model is None:
            print("Loading RF-DETR model...")
            kwargs = {"pretrain_weights": self.model_path}
            custom_resolution = os.getenv("RFDETR_RESOLUTION")
            if custom_resolution:
                try:
                    kwargs["resolution"] = int(custom_resolution)
                    print(f"Using custom RF-DETR resolution: {kwargs['resolution']}")
                except ValueError:
                    print(f"Warning: invalid RFDETR_RESOLUTION '{custom_resolution}'. Falling back to model default.")
            else:
                kwargs["resolution"] = 1600  # Default tuned for CU-1 deployment
            self.model = RFDETRMedium(**kwargs)
            print("RF-DETR model loaded successfully!")

    def _load_ocr(self):
        """Load EasyOCR reader for text extraction"""
        if self.enable_ocr and self.ocr_reader is None:
            print("Loading OCR reader...")
            self.ocr_reader = easyocr.Reader(['en', 'fr'], gpu=torch.cuda.is_available())
            print("OCR reader loaded successfully!")

    def _load_blip(self):
        """Load BLIP model for image captioning"""
        if self.enable_blip and (self.blip_processor is None or self.blip_model is None):
            print("Loading BLIP model for icon description...")
            self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            # Use safetensors format to avoid torch.load vulnerability (CVE-2025-32434)
            self.blip_model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-base",
                use_safetensors=True
            )
            if torch.cuda.is_available():
                self.blip_model = self.blip_model.to("cuda")
            print("BLIP model loaded successfully!")

    def _load_clip(self):
        """Load CLIP model for UI element classification"""
        if self.enable_clip and (self.clip_processor is None or self.clip_model is None):
            print("Loading CLIP model for UI element classification...")
            self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            # Use safetensors format to avoid torch.load vulnerability (CVE-2025-32434)
            self.clip_model = CLIPModel.from_pretrained(
                "openai/clip-vit-base-patch32",
                use_safetensors=True
            )
            if torch.cuda.is_available():
                self.clip_model = self.clip_model.to("cuda")
            print("CLIP model loaded successfully!")

    def _classify_with_clip(self, cropped_img: np.ndarray) -> int:
        """
        Classify UI element using CLIP

        Args:
            cropped_img: Cropped numpy array of the UI element

        Returns:
            Predicted class_id (0-5, corresponding to CLASSES)
        """
        if cropped_img.size == 0:
            return 0  # Default to first class
        if not self.enable_clip:
            return 0  # No classification, return default

        self._load_clip()
        try:
            # Convert numpy array to PIL Image
            pil_img = Image.fromarray(cropped_img)

            # Create text prompts for each class - Optimized for mobile UI
            text_prompts = [
                "a mobile app button or interactive element",
                "a text input field or search bar in a mobile app",
                "text label, heading, or paragraph in a mobile app",
                "an image, icon, or avatar in a mobile app",
                "a list item, card, or tile in a mobile app",
                "a navigation bar, tab, or menu in a mobile app"
            ]

            # Process with CLIP
            inputs = self.clip_processor(
                text=text_prompts,
                images=pil_img,
                return_tensors="pt",
                padding=True
            )
            if torch.cuda.is_available():
                inputs = {k: v.to("cuda") for k, v in inputs.items()}

            # Get predictions
            outputs = self.clip_model(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = logits_per_image.softmax(dim=1)

            # Get the class with highest probability
            predicted_class_id = probs.argmax().item()
            return predicted_class_id
        except Exception as clip_error:
            print(f"CLIP classification error: {clip_error}")
            return 0  # Fallback to default class

    def _extract_text(self, cropped_img: np.ndarray) -> str:
        """Extract plain text from a cropped region using OCR (no BLIP)."""
        if not self.enable_ocr or cropped_img.size == 0:
            return ""
        self._load_ocr()
        try:
            ocr_results = self.ocr_reader.readtext(cropped_img, detail=0)
            return " ".join(ocr_results).strip()
        except Exception as ocr_error:
            print(f"OCR error: {ocr_error}")
            return ""

    def _describe_with_blip(self, cropped_img: np.ndarray) -> str:
        """Generate a visual description using BLIP for a cropped region."""
        if not self.enable_blip or cropped_img.size == 0:
            return ""
        self._load_blip()
        try:
            pil_img = Image.fromarray(cropped_img)
            inputs = self.blip_processor(pil_img, return_tensors="pt")
            if torch.cuda.is_available():
                inputs = {k: v.to("cuda") for k, v in inputs.items()}
            out = self.blip_model.generate(**inputs, max_length=50)
            return self.blip_processor.decode(out[0], skip_special_tokens=True)
        except Exception as blip_error:
            print(f"BLIP error: {blip_error}")
            return ""

    @staticmethod
    def _iou(box_a: Tuple[int, int, int, int], box_b: Tuple[int, int, int, int]) -> float:
        """Calculate Intersection over Union (IoU) between two boxes."""
        xA = max(box_a[0], box_b[0])
        yA = max(box_a[1], box_b[1])
        xB = min(box_a[2], box_b[2])
        yB = min(box_a[3], box_b[3])
        inter_w = max(0, xB - xA)
        inter_h = max(0, yB - yA)
        inter_area = inter_w * inter_h
        if inter_area == 0:
            return 0.0
        box_a_area = max(0, (box_a[2] - box_a[0])) * max(0, (box_a[3] - box_a[1]))
        box_b_area = max(0, (box_b[2] - box_b[0])) * max(0, (box_b[3] - box_b[1]))
        union = box_a_area + box_b_area - inter_area
        if union <= 0:
            return 0.0
        return inter_area / union

    @staticmethod
    def _box_center(box: Tuple[int, int, int, int]) -> Tuple[float, float]:
        """Calculate the center point of a bounding box."""
        x1, y1, x2, y2 = box
        return (x1 + x2) / 2.0, (y1 + y2) / 2.0

    def analyze(
        self,
        image: Union[str, Path, np.ndarray, Image.Image],
        confidence_threshold: float = 0.35,
        extract_text: bool = True,
        use_clip: bool = True,
        use_blip: bool = False,
        merge_global_ocr: bool = True,
        blip_scope: str = "icons",
        preprocess: bool = False,
        preprocess_preset: str = "standard",
        preprocess_mode: str = "rfdetr"
    ) -> Dict:
        """
        Run a single-pass analysis: detection, optional CLIP classification, OCR, optional BLIP,
        and an optional global OCR merge into the nearest detection.

        PIPELINE:
        0. Optional preprocessing (normalize colors, contrast, denoise)
        1. RF-DETR detects UI elements (single class - just bounding boxes)
        2. CLIP classifies each detection into 6 types (if use_clip=True)
        3. OCR extracts text from each detection (if extract_text=True)
        4. BLIP generates descriptions for icons (if use_blip=True)
        5. Global OCR merge attaches stray text to the nearest detections (if merge_global_ocr=True)

        Args:
            image: Input image (path, PIL Image, or numpy array)
            confidence_threshold: Minimum confidence for RF-DETR detections
            extract_text: Whether to run OCR on detections
            use_clip: Whether to classify detections with CLIP
            use_blip: Whether to generate BLIP descriptions
            merge_global_ocr: Whether to run global OCR and merge results
            blip_scope: "icons" (only image/button) or "all" (all elements)
            preprocess: Enable image preprocessing (recommended for cross-device consistency)
            preprocess_mode: Preprocessing mode - 'rfdetr' (optimized for RF-DETR) or 'generic' (for CLIP/OCR)
            preprocess_preset: Preprocessing preset - depends on mode:
                - rfdetr mode: 'gentle', 'standard', 'aggressive_denoise', 'color_only'
                - generic mode: 'standard', 'aggressive', 'minimal', 'ocr_optimized'

        Returns:
            Dict with keys:
                - detections: List of {box, confidence, class_id, class_name, text, description}
                - image_size: {width, height}
                - preprocessed: Whether preprocessing was applied
                - preprocessing_info: Details of the applied preprocessing (None if not preprocessed)
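
        Example (illustrative; boxes, scores, and text are made up):
            results = service.analyze("screenshot.png")
            results["detections"][0]
            # {'box': {'x1': 24.0, 'y1': 180.0, 'x2': 312.0, 'y2': 236.0},
            #  'confidence': 0.91, 'class_id': 0, 'class_name': 'button',
            #  'text': 'Sign in', 'description': ''}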
| """ | |
| # Load image | |
| img_array = load_image(image) | |
| # Optional preprocessing for cross-device consistency | |
| preprocessed = False | |
| preprocessing_info = {} | |
| if preprocess: | |
| try: | |
| if preprocess_mode == "rfdetr": | |
| # RF-DETR optimized preprocessing (preserves ImageNet normalization) | |
| img_array = preprocess_for_rfdetr(img_array, preset=preprocess_preset) | |
| preprocessed = True | |
| preprocessing_info = { | |
| "mode": "rfdetr", | |
| "preset": preprocess_preset, | |
| "description": "RF-DETR optimized (preserves ImageNet normalization)" | |
| } | |
| elif preprocess_mode == "generic": | |
| # Generic preprocessing (for CLIP/OCR optimization) | |
| img_array = preprocess_screenshot(img_array, preset=preprocess_preset) | |
| preprocessed = True | |
| preprocessing_info = { | |
| "mode": "generic", | |
| "preset": preprocess_preset, | |
| "description": "Generic preprocessing (CLIP/OCR optimized)" | |
| } | |
| else: | |
| print(f"Warning: Unknown preprocess_mode '{preprocess_mode}'. Using 'rfdetr'.") | |
| img_array = preprocess_for_rfdetr(img_array, preset="standard") | |
| preprocessed = True | |
| preprocessing_info = { | |
| "mode": "rfdetr", | |
| "preset": "standard", | |
| "description": "RF-DETR optimized (fallback)" | |
| } | |
| except Exception as e: | |
| print(f"Warning: Preprocessing failed: {e}. Continuing with original image.") | |
| preprocessed = False | |
| preprocessing_info = {"error": str(e)} | |
| height, width = img_array.shape[:2] | |

        # RF-DETR Detection: Detects generic UI elements (SINGLE CLASS ONLY)
        det = self.model.predict(img_array, threshold=confidence_threshold)
        boxes = det.xyxy.tolist()
        scores = det.confidence.tolist()

        detections: List[Dict] = []
        for box, score in zip(boxes, scores):
            x1, y1, x2, y2 = map(int, box)
            cropped = img_array[y1:y2, x1:x2]

            # CLIP Classification: Classify RF-DETR detection into one of 6 types
            if use_clip and self.enable_clip:
                predicted_class_id = self._classify_with_clip(cropped)
                class_name = self.CLASSES[predicted_class_id] if 0 <= predicted_class_id < len(self.CLASSES) else "unknown"
            else:
                predicted_class_id = None
                class_name = ""

            # OCR text extraction per detection
            text = self._extract_text(cropped) if extract_text and self.enable_ocr else ""

            # BLIP description per detection (kept separate from text)
            description = ""
            if use_blip and self.enable_blip and (
                blip_scope == "all" or class_name in {"image", "button"}
            ):
                description = self._describe_with_blip(cropped)

            detections.append({
                "box": {"x1": float(x1), "y1": float(y1), "x2": float(x2), "y2": float(y2)},
                "confidence": float(score),
                "class_id": predicted_class_id,
                "class_name": class_name,
                "text": text,
                "description": description,
            })

        # Optional global OCR merge: attach stray OCR text to the nearest detection
        if merge_global_ocr and extract_text and self.enable_ocr:
            try:
                self._load_ocr()
                # detail=1 returns [ [ (x,y)...4 points ], text, conf ]
                global_ocr = self.ocr_reader.readtext(img_array, detail=1)

                # Precompute detection boxes as tuples
                det_boxes: List[Tuple[int, int, int, int]] = []
                for d in detections:
                    b = d["box"]
                    det_boxes.append((int(b["x1"]), int(b["y1"]), int(b["x2"]), int(b["y2"])))

                for entry in global_ocr:
                    if not isinstance(entry, (list, tuple)) or len(entry) < 2:
                        continue
                    quad = entry[0]
                    text = entry[1] if isinstance(entry[1], str) else ""
                    if not text:
                        continue

                    # Convert quadrilateral to bounding box
                    xs = [p[0] for p in quad]
                    ys = [p[1] for p in quad]
                    obox = (int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys)))

                    # Overlap with existing detections (IoU >= 0.1) → attach to best-overlap detection
                    overlaps = [self._iou(obox, db) for db in det_boxes]
                    if overlaps:
                        max_iou = max(overlaps)
                        if max_iou >= 0.1:
                            best_overlap_idx = int(np.argmax(np.array(overlaps)))
                            existing = detections[best_overlap_idx]["text"].strip()
                            if text not in existing:
                                detections[best_overlap_idx]["text"] = (
                                    existing + (" " if existing else "") + text
                                ).strip()
                            # Attached to overlapping detection; proceed to next OCR entry
                            continue

                    # No sufficient overlap → find nearest detection by center distance
                    ox, oy = self._box_center(obox)
                    best_idx = -1
                    best_dist = float("inf")
                    for idx, dbox in enumerate(det_boxes):
                        cx, cy = self._box_center(dbox)
                        dx = cx - ox
                        dy = cy - oy
                        dist2 = dx * dx + dy * dy
                        if dist2 < best_dist:
                            best_dist = dist2
                            best_idx = idx

                    if best_idx >= 0:
                        # Conservative distance threshold: within 0.3 of the detection diagonal
                        bx1, by1, bx2, by2 = det_boxes[best_idx]
                        bw = max(1, bx2 - bx1)
                        bh = max(1, by2 - by1)
                        diag2 = bw * bw + bh * bh
                        if best_dist <= 0.09 * diag2:  # squared distances: (0.3 * diag)^2
                            existing = detections[best_idx]["text"].strip()
                            if text not in existing:
                                detections[best_idx]["text"] = (
                                    existing + (" " if existing else "") + text
                                ).strip()
                            continue

                    # Not overlapping or near any detection → create a new OCR-only detection
                    new_det = {
                        "box": {
                            "x1": float(obox[0]),
                            "y1": float(obox[1]),
                            "x2": float(obox[2]),
                            "y2": float(obox[3]),
                        },
                        "confidence": float(entry[2]) if len(entry) > 2 and entry[2] is not None else 1.0,
                        "class_id": None,
                        "class_name": "",
                        "text": text.strip(),
                        "description": "",
                    }
                    detections.append(new_det)
                    det_boxes.append(obox)
            except Exception as e:
                print(f"Global OCR merge error: {e}")

        return {
            "detections": detections,
            "image_size": {"width": int(width), "height": int(height)},
            "preprocessed": preprocessed,
            "preprocessing_info": preprocessing_info if preprocessed else None
        }

    def _draw_detections(
        self,
        image: np.ndarray,
        boxes: List[List[float]],
        scores: List[float],
        classes: List[int],
        contents: Optional[List[str]] = None,
        thickness: int = 3,
        font_scale: float = 0.5
    ) -> np.ndarray:
        """Draw detection boxes and labels on image"""
        img_with_boxes = image.copy()

        for idx, (box, score, cls_id) in enumerate(zip(boxes, scores, classes)):
            x1, y1, x2, y2 = map(int, box)

            # Draw rectangle
            cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), self.BOX_COLOR, thickness)

            # Prepare label with confidence score
            label = f"{score:.2f}"

            # Add content if available
            content = ""
            if contents and idx < len(contents) and contents[idx]:
                content = contents[idx]
                # Truncate long content for display
                if len(content) > 40:
                    content = content[:37] + "..."

            # Calculate label size and position
            (label_width, label_height), baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness=2
            )

            # Draw label background
            label_y = max(y1 - 10, label_height + 10)
            cv2.rectangle(
                img_with_boxes,
                (x1, label_y - label_height - baseline - 5),
                (x1 + label_width + 5, label_y + baseline - 5),
                self.BOX_COLOR,
                -1
            )

            # Draw label text (confidence score)
            cv2.putText(
                img_with_boxes,
                label,
                (x1 + 2, label_y - baseline - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                font_scale,
                (255, 255, 255),
                thickness=2
            )

            # Draw content text below the box if available
            if content:
                content_font_scale = font_scale * 0.8
                (content_width, content_height), content_baseline = cv2.getTextSize(
                    content, cv2.FONT_HERSHEY_SIMPLEX, content_font_scale, thickness=1
                )
                # Position content below the bottom of the box
                content_y = min(y2 + content_height + 15, img_with_boxes.shape[0] - 5)

                # Draw content background
                cv2.rectangle(
                    img_with_boxes,
                    (x1, content_y - content_height - content_baseline - 3),
                    (x1 + content_width + 5, content_y + content_baseline),
                    (0, 180, 0),  # Slightly darker green
                    -1
                )

                # Draw content text
                cv2.putText(
                    img_with_boxes,
                    content,
                    (x1 + 2, content_y - content_baseline - 3),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    content_font_scale,
                    (255, 255, 255),
                    thickness=1
                )

        return img_with_boxes

    def get_prediction_image(
        self,
        image: Union[str, Path, np.ndarray, Image.Image],
        confidence_threshold: float = 0.35,
        extract_content: bool = True,
        thickness: int = 3,
        font_scale: float = 0.5,
        return_format: str = "pil",
        analysis: Optional[Dict] = None
    ) -> Union[Image.Image, np.ndarray]:
        """
        Get annotated image with detection boxes drawn

        Args:
            image: Input image (path, PIL Image, or numpy array)
            confidence_threshold: Minimum confidence score for detections (0.0-1.0)
            extract_content: Whether to extract and display text content or icon descriptions
            thickness: Thickness of bounding box lines
            font_scale: Font scale for labels
            return_format: Return format - "pil" for PIL Image or "numpy" for numpy array
            analysis: Pre-computed analysis results (optional, for performance)

        Returns:
            Annotated image as PIL Image or numpy array (RGB)
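
        Example (illustrative; file names are placeholders):
            annotated = service.get_prediction_image("screenshot.png", return_format="pil")
            annotated.save("annotated.png")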
| """ | |
| # Load image | |
| img_array = load_image(image) | |
| if analysis is None: | |
| analysis = self.analyze( | |
| image, | |
| confidence_threshold=confidence_threshold, | |
| extract_text=extract_content, | |
| use_clip=self.enable_clip, | |
| use_blip=self.enable_blip, | |
| merge_global_ocr=True | |
| ) | |
| boxes = [] | |
| scores = [] | |
| class_ids = [] | |
| contents = [] | |
| for det in analysis["detections"]: | |
| b = det["box"] | |
| boxes.append([b["x1"], b["y1"], b["x2"], b["y2"]]) | |
| scores.append(det["confidence"]) | |
| class_ids.append(det["class_id"] if det.get("class_id") is not None else 0) | |
| if extract_content: | |
| text = det.get("text") or "" | |
| desc = det.get("description") or "" | |
| contents.append(text if text else (f"[Icon: {desc}]" if desc else "")) | |
| # Convert to BGR for OpenCV | |
| img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR) | |
| # Draw detections | |
| annotated_img = self._draw_detections( | |
| img_bgr, boxes, scores, class_ids, | |
| contents if extract_content else None, | |
| thickness, font_scale | |
| ) | |
| # Convert back to RGB | |
| annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB) | |
| # Return in requested format | |
| if return_format.lower() == "pil": | |
| return Image.fromarray(annotated_img_rgb) | |
| else: | |
| return annotated_img_rgb | |