| import torch |
| import math |
| from PIL import Image |
| from typing import Dict, List, Tuple |
| from rapidfuzz import fuzz |
| from prompt_library_manager import PromptLibraryManager |
| from brand_detection_optimizer import BrandDetectionOptimizer |
|
|
class BrandRecognitionManager:
    """Multi-modal brand recognition with detailed prompts (Visual + Text)."""

    def __init__(self, clip_manager, ocr_manager, prompt_library=None):
        """Initialize the manager.

        Args:
            clip_manager: CLIP wrapper exposing ``classify_zero_shot``.
            ocr_manager: OCR wrapper exposing ``extract_text``.
            prompt_library: Brand prompt library exposing ``get_all_brands``
                and ``get_brand_prompts``. When None, a default
                ``PromptLibraryManager`` is created (previously a None
                default crashed with AttributeError on the next line).
        """
        # Honor the documented default instead of dereferencing None.
        if prompt_library is None:
            prompt_library = PromptLibraryManager()

        self.clip_manager = clip_manager
        self.ocr_manager = ocr_manager
        self.prompt_library = prompt_library
        # Flat {brand_name: brand_info} mapping used by all scoring paths.
        self.flat_brands = prompt_library.get_all_brands()

        # Optimizer supplies prescreening / smart region selection for the
        # full-image scan path.
        self.optimizer = BrandDetectionOptimizer(clip_manager, ocr_manager, prompt_library)

        print(f"✓ Brand Recognition Manager loaded with {len(self.flat_brands)} brands (with optimizer)")
|
|
| def recognize_brand(self, image_region: Image.Image, full_image: Image.Image, |
| region_bbox: List[int] = None) -> List[Tuple[str, float, List[int]]]: |
| """Recognize brands using detailed context-aware prompts |
| |
| Args: |
| image_region: Cropped region containing potential brand |
| full_image: Full image for OCR |
| region_bbox: Bounding box [x1, y1, x2, y2] for visualization |
| |
| Returns: |
| List of (brand_name, confidence, bbox) tuples |
| """ |
|
|
| |
| region_context = self._classify_region_context(image_region) |
| print(f" [DEBUG] Region context classified as: {region_context}") |
|
|
| |
| brand_scores = {} |
|
|
| for brand_name, brand_info in self.flat_brands.items(): |
| |
| best_context = self._match_region_to_brand_context(region_context, brand_info['region_contexts']) |
|
|
| if best_context and best_context in brand_info['openclip_prompts']: |
| |
| prompts = brand_info['openclip_prompts'][best_context] |
| visual_scores = self.clip_manager.classify_zero_shot(image_region, prompts) |
|
|
| |
| avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0 |
| else: |
| |
| prompts = brand_info['strong_cues'][:5] |
| visual_scores = self.clip_manager.classify_zero_shot(image_region, prompts) |
| avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0 |
|
|
| brand_scores[brand_name] = avg_score |
|
|
| |
| brand_scores = self._multi_scale_visual_matching(image_region, brand_scores) |
|
|
| |
| ocr_results = self.ocr_manager.extract_text(full_image, use_brand_preprocessing=True) |
| text_matches = self._fuzzy_text_matching(ocr_results) |
|
|
| print(f" [DEBUG] OCR found {len(ocr_results)} text regions") |
| if text_matches: |
| print(f" [DEBUG] OCR brand matches: {text_matches}") |
|
|
| |
| final_scores = {} |
| for brand_name in self.flat_brands.keys(): |
| visual_score = brand_scores.get(brand_name, 0.0) |
| text_score, ocr_conf = text_matches.get(brand_name, (0.0, 0.0)) |
|
|
| |
| visual_weight, text_weight, ocr_weight = self._calculate_adaptive_weights( |
| brand_name, visual_score, text_score, ocr_conf |
| ) |
|
|
| |
| final_score = ( |
| visual_weight * self._scale_visual(visual_score) + |
| text_weight * text_score + |
| ocr_weight * ocr_conf |
| ) |
| final_scores[brand_name] = final_score |
|
|
| sorted_scores = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)[:5] |
| print(f" [DEBUG] Top 5 brand scores:") |
| for brand, score in sorted_scores: |
| print(f" {brand}: {score:.4f} (visual={brand_scores.get(brand, 0):.4f}, text={text_matches.get(brand, (0, 0))[0]:.4f})") |
|
|
| |
| confident_brands = [] |
| for brand_name, score in final_scores.items(): |
| if score > 0.10: |
| confident_brands.append((brand_name, score, region_bbox)) |
| print(f" [DEBUG] ✓ Brand detected: {brand_name} (confidence: {score:.4f})") |
|
|
| confident_brands.sort(key=lambda x: x[1], reverse=True) |
|
|
| if not confident_brands: |
| print(f" [DEBUG] ✗ No brands passed threshold 0.10") |
|
|
| return confident_brands |
|
|
| def _classify_region_context(self, image_region: Image.Image) -> str: |
| """Classify what type of region this is (bag_panel, shoe_side, etc.)""" |
| context_labels = [ |
| 'bag panel with pattern', |
| 'luggage surface with branding', |
| 'luxury trunk with monogram pattern', |
| 'vintage travel trunk with hardware', |
| 'shoe side view', |
| 'device back cover', |
| 'apparel chest area', |
| 'belt buckle', |
| 'storefront sign', |
| 'product tag or label', |
| 'wallet surface', |
| 'perfume bottle', |
| 'watch dial or face', |
| 'car front grille', |
| 'laptop lid' |
| ] |
|
|
| scores = self.clip_manager.classify_zero_shot(image_region, context_labels) |
|
|
| |
| context_mapping = { |
| 'bag panel with pattern': 'bag_panel', |
| 'luggage surface with branding': 'luggage_surface', |
| 'luxury trunk with monogram pattern': 'trunk_body', |
| 'vintage travel trunk with hardware': 'trunk_body', |
| 'shoe side view': 'shoe_side', |
| 'device back cover': 'device_back', |
| 'apparel chest area': 'apparel_chest', |
| 'belt buckle': 'belt_buckle', |
| 'storefront sign': 'storefront', |
| 'product tag or label': 'product_tag', |
| 'wallet surface': 'wallet', |
| 'perfume bottle': 'perfume_bottle', |
| 'watch dial or face': 'watch_dial', |
| 'car front grille': 'car_front', |
| 'laptop lid': 'laptop_lid' |
| } |
|
|
| top_context = max(scores.items(), key=lambda x: x[1])[0] |
| return context_mapping.get(top_context, 'unknown') |
|
|
| def _match_region_to_brand_context(self, region_context: str, brand_contexts: List[str]) -> str: |
| """Match detected region context to brand's available contexts""" |
| if region_context in brand_contexts: |
| return region_context |
| |
| for brand_context in brand_contexts: |
| if region_context.split('_')[0] in brand_context: |
| return brand_context |
| return None |
|
|
| def _fuzzy_text_matching(self, ocr_results: List[Dict]) -> Dict[str, Tuple[float, float]]: |
| """Fuzzy text matching using brand aliases (optimized for logo text)""" |
| matches = {} |
|
|
| for ocr_item in ocr_results: |
| text = ocr_item['text'] |
| conf = ocr_item['confidence'] |
|
|
| for brand_name, brand_info in self.flat_brands.items(): |
| |
| all_names = [brand_name] + brand_info.get('aliases', []) |
|
|
| for alias in all_names: |
| ratio = fuzz.ratio(text, alias) / 100.0 |
| if ratio > 0.70: |
| if brand_name not in matches or ratio > matches[brand_name][0]: |
| matches[brand_name] = (ratio, conf) |
|
|
| return matches |
|
|
| def _scale_visual(self, score: float) -> float: |
| """Scale visual score using sigmoid""" |
| return 1 / (1 + math.exp(-10 * (score - 0.5))) |
|
|
| def _calculate_adaptive_weights(self, brand_name: str, visual_score: float, |
| text_score: float, ocr_conf: float) -> tuple: |
| """ |
| Calculate adaptive weights based on brand characteristics and signal strengths |
| |
| Args: |
| brand_name: Name of the brand |
| visual_score: Visual similarity score |
| text_score: Text matching score |
| ocr_conf: OCR confidence |
| |
| Returns: |
| Tuple of (visual_weight, text_weight, ocr_weight) |
| """ |
| brand_info = self.prompt_library.get_brand_prompts(brand_name) |
|
|
| if not brand_info: |
| |
| return 0.50, 0.30, 0.20 |
|
|
| |
| if brand_info.get('visual_distinctive', False): |
| |
| visual_weight = 0.65 |
| text_weight = 0.20 |
| ocr_weight = 0.15 |
| elif brand_info.get('text_prominent', False): |
| |
| visual_weight = 0.30 |
| text_weight = 0.30 |
| ocr_weight = 0.40 |
| else: |
| |
| visual_weight = 0.50 |
| text_weight = 0.30 |
| ocr_weight = 0.20 |
|
|
| |
| |
| if visual_score > 0.7: |
| boost = 0.10 |
| visual_weight += boost |
| text_weight -= boost * 0.5 |
| ocr_weight -= boost * 0.5 |
|
|
| |
| if ocr_conf > 0.85: |
| boost = 0.10 |
| ocr_weight += boost |
| visual_weight -= boost * 0.6 |
| text_weight -= boost * 0.4 |
|
|
| |
| if text_score > 0.80: |
| boost = 0.08 |
| text_weight += boost |
| visual_weight -= boost * 0.5 |
| ocr_weight -= boost * 0.5 |
|
|
| |
| total = visual_weight + text_weight + ocr_weight |
| return visual_weight / total, text_weight / total, ocr_weight / total |
|
|
| def _multi_scale_visual_matching(self, image_region: Image.Image, |
| initial_scores: Dict[str, float]) -> Dict[str, float]: |
| """ |
| Apply multi-scale matching to improve robustness |
| |
| Args: |
| image_region: Image region to analyze |
| initial_scores: Initial brand scores from single-scale matching |
| |
| Returns: |
| Updated brand scores with multi-scale matching |
| """ |
| scales = [0.8, 1.0, 1.2] |
| multi_scale_scores = {brand: [] for brand in initial_scores.keys()} |
|
|
| for scale in scales: |
| |
| new_width = int(image_region.width * scale) |
| new_height = int(image_region.height * scale) |
|
|
| |
| if new_width < 50 or new_height < 50: |
| continue |
|
|
| try: |
| scaled_img = image_region.resize((new_width, new_height), Image.Resampling.LANCZOS) |
|
|
| |
| for brand_name, brand_info in self.flat_brands.items(): |
| |
| best_context = self._match_region_to_brand_context( |
| 'bag_panel', |
| brand_info.get('region_contexts', []) |
| ) |
|
|
| if best_context and best_context in brand_info.get('openclip_prompts', {}): |
| prompts = brand_info['openclip_prompts'][best_context] |
| visual_scores = self.clip_manager.classify_zero_shot(scaled_img, prompts) |
| avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0 |
| else: |
| prompts = brand_info.get('strong_cues', [])[:3] |
| visual_scores = self.clip_manager.classify_zero_shot(scaled_img, prompts) |
| avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0 |
|
|
| multi_scale_scores[brand_name].append(avg_score) |
|
|
| except Exception as e: |
| |
| continue |
|
|
| |
| final_scores = {} |
| for brand_name, scores in multi_scale_scores.items(): |
| if scores: |
| final_scores[brand_name] = max(scores) |
| else: |
| final_scores[brand_name] = initial_scores.get(brand_name, 0.0) |
|
|
| return final_scores |
|
|
| def scan_full_image_for_brands(self, full_image: Image.Image, |
| exclude_bboxes: List[List[int]] = None, |
| saliency_regions: List[Dict] = None) -> List[Tuple[str, float, List[int]]]: |
| """ |
| 智能全圖品牌掃描 - 性能優化版本 |
| 使用預篩選和智能區域選擇大幅減少檢測時間 |
| |
| Args: |
| full_image: PIL Image (full image) |
| exclude_bboxes: List of bboxes to exclude (already detected) |
| saliency_regions: Saliency detection results for smart region selection |
| |
| Returns: |
| List of (brand_name, confidence, bbox) tuples |
| """ |
| if exclude_bboxes is None: |
| exclude_bboxes = [] |
|
|
| detected_brands = {} |
| img_width, img_height = full_image.size |
|
|
| |
| likely_brands = self.optimizer.quick_brand_prescreening(full_image) |
| print(f" Quick prescreening found {len(likely_brands)} potential brands") |
|
|
| |
| regions_to_scan = self.optimizer.smart_region_selection(full_image, saliency_regions or []) |
| print(f" Scanning {len(regions_to_scan)} intelligent regions") |
|
|
| |
| for region_bbox in regions_to_scan: |
| x1, y1, x2, y2 = region_bbox |
|
|
| |
| if self._bbox_overlap(list(region_bbox), exclude_bboxes): |
| continue |
|
|
| |
| region = full_image.crop(region_bbox) |
|
|
| |
| for brand_name in likely_brands: |
| brand_info = self.flat_brands.get(brand_name) |
| if not brand_info: |
| continue |
|
|
| |
| strong_cues = brand_info.get('strong_cues', [])[:5] |
| if not strong_cues: |
| continue |
|
|
| visual_scores = self.clip_manager.classify_zero_shot(region, strong_cues) |
| avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0 |
|
|
| |
| ocr_results = self.ocr_manager.extract_text(full_image, use_brand_preprocessing=True) |
| boosted_score = self.optimizer.compute_brand_confidence_boost( |
| brand_name, ocr_results, avg_score |
| ) |
|
|
| |
| if boosted_score > 0.08: |
| |
| if brand_name not in detected_brands or boosted_score > detected_brands[brand_name][0]: |
| detected_brands[brand_name] = (boosted_score, list(region_bbox)) |
|
|
| |
| final_brands = [ |
| (brand_name, confidence, bbox) |
| for brand_name, (confidence, bbox) in detected_brands.items() |
| ] |
|
|
| |
| final_brands.sort(key=lambda x: x[1], reverse=True) |
|
|
| return final_brands[:5] |
|
|
| def _bbox_overlap(self, bbox1: List[int], bbox_list: List[List[int]]) -> bool: |
| """Check if bbox1 overlaps significantly with any bbox in bbox_list""" |
| if not bbox_list: |
| return False |
|
|
| x1_1, y1_1, x2_1, y2_1 = bbox1 |
|
|
| for bbox2 in bbox_list: |
| if bbox2 is None: |
| continue |
|
|
| x1_2, y1_2, x2_2, y2_2 = bbox2 |
|
|
| |
| x_left = max(x1_1, x1_2) |
| y_top = max(y1_1, y1_2) |
| x_right = min(x2_1, x2_2) |
| y_bottom = min(y2_1, y2_2) |
|
|
| if x_right < x_left or y_bottom < y_top: |
| continue |
|
|
| intersection_area = (x_right - x_left) * (y_bottom - y_top) |
| bbox1_area = (x2_1 - x1_1) * (y2_1 - y1_1) |
|
|
| |
| if intersection_area / bbox1_area > 0.3: |
| return True |
|
|
| return False |
|
|
# Import-time confirmation that the class definition above executed.
print("✓ BrandRecognitionManager (with full-image scan for commercial use) defined")
|
|