import numpy as np
import torch
import torch.nn.functional as F
import cv2
from typing import List, Dict, Tuple, Any, Optional

from src.exceptions import (
    BOMDetectorException,
    InvalidImageException,
    CancellationState,
    DetectionCancelledException,
)
from src.metrics import PerformanceTracker
from src.io_validation import validate_inputs
from src.preprocessing import (
    synchronize_polarity,
    preprocess_for_matching,
    filter_informative_proposals,
)
from src.features import get_shared_feature_extractor
from src.engines import multiscale_template_match, soft_nms, refine_bbox_local_search

# Pipelines understood by ``PatternDetector.detect``.
_SUPPORTED_MODES = ("v1", "v2", "v3")
# Fixed coarse-match threshold used by the "v2" pipeline for candidate generation.
_V2_CANDIDATE_THRESHOLD = 0.35
# IoU threshold of the coarse Soft-NMS pass that runs before CNN scoring.
_COARSE_NMS_IOU = 0.5
# In "auto" extractor mode, templates whose shorter side is below this use "resnet18".
_SMALL_TEMPLATE_SIDE = 56
# Search radius handed to ``refine_bbox_local_search``.
_REFINE_SEARCH_RADIUS = 8


def _check_cancelled(cancellation_state: Optional[CancellationState]) -> None:
    """Call ``cancellation_state.check()`` when a cancellation state was supplied."""
    if cancellation_state is not None:
        cancellation_state.check()


def generate_template_variants(
    template: np.ndarray,
) -> List[Tuple[np.ndarray, str]]:
    """Return the template in four orientations, each paired with a label.

    The variants are the unrotated copy ("R0"), 90 degrees clockwise ("R90"),
    180 degrees ("R180") and 90 degrees counter-clockwise ("R270"). Rotating
    the template means the drawing itself never has to be rotated.
    """
    return [
        (template.copy(), "R0"),
        (cv2.rotate(template, cv2.ROTATE_90_CLOCKWISE), "R90"),
        (cv2.rotate(template, cv2.ROTATE_180), "R180"),
        (cv2.rotate(template, cv2.ROTATE_90_COUNTERCLOCKWISE), "R270"),
    ]


class PatternDetector:
    """Orchestrate the whole matching flow and record per-stage performance.

    Typical use: ``load_drawing`` -> ``add_templates`` -> ``detect``.
    """

    def __init__(self, device: str = "cpu") -> None:
        self.device = device
        # Both stay None until ``load_drawing`` succeeds.
        self.drawing_raw: Optional[np.ndarray] = None
        self.drawing_gray: Optional[np.ndarray] = None
        # (template image, rotation label) pairs registered by ``add_templates``.
        self.templates_variants: List[Tuple[np.ndarray, str]] = []
        self.tracker = PerformanceTracker()

    def clear(self) -> None:
        """Drop the drawing, the templates and the tracker; empty the CUDA cache."""
        self.drawing_raw = None
        self.drawing_gray = None
        self.templates_variants = []
        self.tracker = PerformanceTracker()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _to_gray(img: np.ndarray) -> np.ndarray:
        """Return a grayscale version of *img* (3-D input is treated as BGR)."""
        if img.ndim == 3:
            return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        return img.copy()

    def load_drawing(self, drawing_img: np.ndarray) -> None:
        """Store a copy of the drawing together with its grayscale version.

        Raises:
            InvalidImageException: if the image is ``None``/empty or cannot be
                converted. The detector state is cleared before raising.
        """
        try:
            self.tracker.start_stage("load_and_normalize_drawing")
            if drawing_img is None or drawing_img.size == 0:
                raise InvalidImageException("Ảnh bản vẽ đầu vào trống.")

            self.drawing_raw = drawing_img.copy()
            self.drawing_gray = self._to_gray(drawing_img)
            self.tracker.end_stage("load_and_normalize_drawing")
        except BOMDetectorException:
            self.clear()
            raise
        except Exception as e:
            self.clear()
            # Chain the cause so the original traceback is not lost.
            raise InvalidImageException(f"Lỗi nạp bản vẽ: {e}") from e

    def add_templates(self, templates: List[np.ndarray], with_rotation: bool = False) -> None:
        """Register templates, polarity-synchronised against the loaded drawing.

        Any previously registered templates are replaced.

        Args:
            templates: template images; 3-D arrays are converted from BGR to gray.
            with_rotation: when true, register the four rotated variants of each
                template instead of only the unrotated one.

        Raises:
            BOMDetectorException: if no drawing is loaded, the list is empty, or
                any other failure occurs (``InvalidImageException`` for an empty
                template). The detector state is cleared before raising.
        """
        try:
            self.tracker.start_stage("add_templates")
            if self.drawing_gray is None:
                # Fail with a clear message instead of handing None to the validator.
                raise BOMDetectorException("Drawing chưa được nạp.")
            if not templates:
                raise BOMDetectorException("Không có template.")

            self.templates_variants = []
            for tmpl in templates:
                if tmpl is None or tmpl.size == 0:
                    raise InvalidImageException("Ảnh mẫu trống.")

                tmpl_gray = self._to_gray(tmpl)
                validate_inputs(self.drawing_gray, tmpl_gray)
                # Only the second element (the template side) is needed here.
                _, tmpl_sync = synchronize_polarity(self.drawing_gray, tmpl_gray)

                if with_rotation:
                    self.templates_variants.extend(generate_template_variants(tmpl_sync))
                else:
                    self.templates_variants.append((tmpl_sync, "R0"))
            self.tracker.end_stage("add_templates")
        except BOMDetectorException:
            self.clear()
            raise
        except Exception as e:
            self.clear()
            raise BOMDetectorException(f"Lỗi đăng ký template: {e}") from e

    def _coarse_candidates(
        self,
        drawing_edge: np.ndarray,
        drawing_sync: np.ndarray,
        tmpl_edge: np.ndarray,
        match_threshold: float,
        std_threshold: float,
        coarse_stage: str,
        filter_stage: str,
        cancellation_state: Optional[CancellationState],
    ) -> List[Any]:
        """Run the coarse match, prune with Soft-NMS, then drop blank proposals.

        Proposals are indexable records: items 0-3 are the bbox (x, y, w, h),
        item 4 the coarse confidence and item 5 the scale.
        """
        self.tracker.start_stage(coarse_stage)
        proposals = multiscale_template_match(
            drawing_edge,
            tmpl_edge,
            threshold=match_threshold,
            cancellation_state=cancellation_state,
        )
        self.tracker.end_stage(coarse_stage)
        if not proposals:
            return []

        # Coarse NMS to prune candidates before the heavy CNN pass.
        pruned = soft_nms(
            [
                {"bbox": p[:4], "confidence": float(p[4]), "scale": float(p[5])}
                for p in proposals
            ],
            iou_threshold=_COARSE_NMS_IOU,
            score_threshold=match_threshold,
            method="gaussian",
        )
        proposals = [
            (
                d["bbox"][0],
                d["bbox"][1],
                d["bbox"][2],
                d["bbox"][3],
                d["confidence"],
                d["scale"],
            )
            for d in pruned
        ]

        self.tracker.start_stage(filter_stage)
        proposals = filter_informative_proposals(
            proposals, drawing_sync, std_threshold=std_threshold
        )
        self.tracker.end_stage(filter_stage)
        return proposals

    def _cnn_scores(
        self,
        proposals: List[Any],
        drawing_sync: np.ndarray,
        tmpl_sync: np.ndarray,
        extractor_type: str,
        context_margin_pct: float,
        init_stage: str,
        batch_stage: str,
    ) -> torch.Tensor:
        """Return the cosine similarity between the template and each proposal crop.

        Each crop is the proposal bbox enlarged by ``context_margin_pct`` of its
        width/height on every side, clipped to the drawing bounds.
        """
        self.tracker.start_stage(init_stage)
        backbone = extractor_type
        if extractor_type == "auto":
            th, tw = tmpl_sync.shape[:2]
            backbone = "resnet18" if min(th, tw) < _SMALL_TEMPLATE_SIDE else "dinov2"
        extractor = get_shared_feature_extractor(backbone=backbone, device=self.device)
        self.tracker.end_stage(init_stage)

        self.tracker.start_stage(batch_stage)
        height, width = drawing_sync.shape[:2]
        crops = []
        for p in proposals:
            x, y, bw, bh = p[0], p[1], p[2], p[3]
            margin_y = int(bh * context_margin_pct)
            margin_x = int(bw * context_margin_pct)
            x1 = max(0, x - margin_x)
            y1 = max(0, y - margin_y)
            x2 = min(width, x + bw + margin_x)
            y2 = min(height, y + bh + margin_y)
            crops.append(drawing_sync[y1:y2, x1:x2])

        tmpl_vec = extractor.extract(tmpl_sync)
        crop_vecs = extractor.extract_batch(crops)
        # Broadcast the single template vector against every crop vector.
        tmpl_vecs = tmpl_vec.unsqueeze(0).expand(len(crops), -1)
        scores = F.cosine_similarity(crop_vecs, tmpl_vecs, dim=1)
        self.tracker.end_stage(batch_stage)
        return scores

    def detect(
        self,
        mode: str = "v3",
        confidence_threshold: float = 0.75,
        v1_threshold: float = 0.50,
        v2_threshold: float = 0.80,
        alpha: float = 0.30,
        iou_threshold: float = 0.30,
        enable_local_refine: bool = False,
        variance_std_threshold: float = 5.0,
        context_margin_pct: float = 0.15,
        extractor_type: str = "auto",
        cancellation_state: Optional[CancellationState] = None,
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Run detection for every registered template variant.

        Args:
            mode: "v1" (edge template matching only), "v2" (coarse candidates
                re-scored by CNN feature similarity) or "v3" (like "v2", with
                the coarse and CNN scores fused).
            confidence_threshold: score threshold of the final Soft-NMS pass.
            v1_threshold: coarse matching threshold ("v1" and "v3").
            v2_threshold: minimum CNN cosine similarity to keep a candidate.
            alpha: weight of the coarse score in the "v3" fusion
                ``alpha * s_v1 + (1 - alpha) * s_v2``.
            iou_threshold: IoU threshold of the final Soft-NMS pass.
            enable_local_refine: refine each final bbox with a local search.
            variance_std_threshold: passed as ``std_threshold`` to the
                blank-proposal filter.
            context_margin_pct: relative margin added around each crop before
                CNN feature extraction.
            extractor_type: backbone name, or "auto" to pick one from the
                template size.
            cancellation_state: optional object whose ``check()`` is called
                between stages.

        Returns:
            ``(results, report)``. Each result is a dict with "bbox",
            "confidence", "rotation", "scale" and "variant_idx" ("v3" also adds
            "score_v1" and "score_v2"). ``report`` is the tracker report plus
            "num_proposals_total" and "num_detected".

        Raises:
            BOMDetectorException: no drawing or templates loaded, unsupported
                ``mode``, or any unexpected failure (state is cleared on
                failures inside the pipeline).
            DetectionCancelledException: re-raised as-is after emptying the
                CUDA cache; loaded state is kept.
        """
        _check_cancelled(cancellation_state)

        if self.drawing_gray is None:
            raise BOMDetectorException("Drawing chưa được nạp.")
        if not self.templates_variants:
            raise BOMDetectorException("Template chưa được đăng ký.")
        if mode not in _SUPPORTED_MODES:
            # An unknown mode used to fall through every branch and silently
            # return an empty result; fail loudly before doing any work.
            raise BOMDetectorException(
                f"Chế độ không hợp lệ: {mode!r}. Chỉ hỗ trợ: v1, v2, v3."
            )

        all_results: List[Dict[str, Any]] = []

        try:
            _check_cancelled(cancellation_state)

            # 1. Drawing polarity synchronisation, done once for all variants.
            self.tracker.start_stage("Prep_Drawing_Polarity")
            drawing_sync, _ = synchronize_polarity(self.drawing_gray, self.drawing_gray)
            self.tracker.end_stage("Prep_Drawing_Polarity")

            # 2. Drawing edge preprocessing, done once for all variants.
            self.tracker.start_stage("Prep_Drawing_Edges")
            drawing_edge = preprocess_for_matching(drawing_sync, method="dilated_edge")
            self.tracker.end_stage("Prep_Drawing_Edges")

            # 3. Template edges, index-aligned with self.templates_variants.
            self.tracker.start_stage("Prep_Template_Edges")
            tmpl_edges = [
                preprocess_for_matching(tmpl, method="dilated_edge")
                for tmpl, _ in self.templates_variants
            ]
            self.tracker.end_stage("Prep_Template_Edges")

            _check_cancelled(cancellation_state)

            # Templates were already polarity-synchronised in add_templates().
            for idx, (tmpl_sync, rotation_name) in enumerate(self.templates_variants):
                _check_cancelled(cancellation_state)
                tmpl_edge = tmpl_edges[idx]

                if mode == "v1":
                    # Edges are precomputed; the empty stage is kept so the
                    # report still carries the V1_Prep_* entry.
                    self.tracker.start_stage(f"V1_Prep_{rotation_name}")
                    self.tracker.end_stage(f"V1_Prep_{rotation_name}")

                    self.tracker.start_stage(f"V1_Matching_{rotation_name}")
                    proposals = multiscale_template_match(
                        drawing_edge,
                        tmpl_edge,
                        threshold=v1_threshold,
                        cancellation_state=cancellation_state,
                    )
                    self.tracker.end_stage(f"V1_Matching_{rotation_name}")

                    all_results.extend(
                        {
                            "bbox": p[:4],
                            "confidence": float(p[4]),
                            "rotation": rotation_name,
                            "scale": float(p[5]),
                            "variant_idx": idx,
                        }
                        for p in proposals
                    )
                    continue

                # "v2" and "v3" share candidate generation and CNN scoring;
                # they differ in the coarse threshold, stage names and scoring.
                if mode == "v2":
                    prefix = "V2"
                    coarse_stage = f"V2_Candidate_Gen_{rotation_name}"
                    match_threshold = _V2_CANDIDATE_THRESHOLD
                else:
                    prefix = "V3"
                    coarse_stage = f"V3_Coarse_V1_{rotation_name}"
                    match_threshold = v1_threshold

                proposals = self._coarse_candidates(
                    drawing_edge,
                    drawing_sync,
                    tmpl_edge,
                    match_threshold=match_threshold,
                    std_threshold=variance_std_threshold,
                    coarse_stage=coarse_stage,
                    filter_stage=f"{prefix}_Blank_Filtering_{rotation_name}",
                    cancellation_state=cancellation_state,
                )
                if not proposals:
                    continue

                _check_cancelled(cancellation_state)

                scores = self._cnn_scores(
                    proposals,
                    drawing_sync,
                    tmpl_sync,
                    extractor_type=extractor_type,
                    context_margin_pct=context_margin_pct,
                    init_stage=f"{prefix}_CNN_Init_{rotation_name}",
                    batch_stage=f"{prefix}_Batch_CNN_{rotation_name}",
                )

                if mode == "v2":
                    # One host transfer for all scores instead of one per item.
                    for p, s_v2 in zip(proposals, scores.tolist()):
                        if s_v2 >= v2_threshold:
                            all_results.append({
                                "bbox": p[:4],
                                "confidence": s_v2,
                                "rotation": rotation_name,
                                "scale": float(p[5]),
                                "variant_idx": idx,
                            })
                else:
                    self.tracker.start_stage(f"V3_Score_Fusion_{rotation_name}")
                    for p, s_v2 in zip(proposals, scores.tolist()):
                        if s_v2 < v2_threshold:
                            continue
                        s_v1 = float(p[4])
                        all_results.append({
                            "bbox": p[:4],
                            "confidence": alpha * s_v1 + (1 - alpha) * s_v2,
                            "score_v1": s_v1,
                            "score_v2": s_v2,
                            "rotation": rotation_name,
                            "scale": float(p[5]),
                            "variant_idx": idx,
                        })
                    self.tracker.end_stage(f"V3_Score_Fusion_{rotation_name}")

            _check_cancelled(cancellation_state)

            # Merge overlapping detections across all variants with Soft-NMS.
            self.tracker.start_stage("Postprocessing_Soft_NMS")
            nms_results = soft_nms(
                all_results,
                iou_threshold=iou_threshold,
                score_threshold=confidence_threshold,
                method="gaussian",
            )
            self.tracker.end_stage("Postprocessing_Soft_NMS")

            # Optional local refinement of each surviving bbox.
            if enable_local_refine and nms_results:
                self.tracker.start_stage("BBox_Local_Refinement")
                for res in nms_results:
                    _check_cancelled(cancellation_state)
                    x, y, w, h = res["bbox"]
                    # Refine against the edge map of the variant that matched.
                    rx, ry, rw, rh, _ = refine_bbox_local_search(
                        drawing_edge,
                        (x, y, w, h),
                        tmpl_edges[res["variant_idx"]],
                        search_radius=_REFINE_SEARCH_RADIUS,
                    )
                    res["bbox"] = (rx, ry, rw, rh)
                self.tracker.end_stage("BBox_Local_Refinement")

            # Build the performance report.
            report = self.tracker.get_report()
            report["num_proposals_total"] = len(all_results)
            report["num_detected"] = len(nms_results)

            return nms_results, report

        except DetectionCancelledException:
            # Cancellation keeps the loaded drawing/templates for a later retry.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise
        except BOMDetectorException:
            self.clear()
            raise
        except Exception as e:
            self.clear()
            raise BOMDetectorException(f"Lỗi trong quá trình detect: {e}") from e