Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import cv2 | |
| from typing import List, Dict, Tuple, Any, Optional | |
| from src.exceptions import ( | |
| BOMDetectorException, | |
| InvalidImageException, | |
| CancellationState, | |
| DetectionCancelledException | |
| ) | |
| from src.metrics import PerformanceTracker | |
| from src.io_validation import validate_inputs | |
| from src.preprocessing import ( | |
| synchronize_polarity, | |
| preprocess_for_matching, | |
| filter_informative_proposals | |
| ) | |
| from src.features import get_shared_feature_extractor | |
| from src.engines import multiscale_template_match, soft_nms, refine_bbox_local_search | |
def generate_template_variants(
    template: np.ndarray,
) -> List[Tuple[np.ndarray, str]]:
    """Return the template in four orientations, each tagged with a label.

    The result always has four entries, in this order: an unrotated copy
    (``"R0"``), then the template rotated with ``cv2.ROTATE_90_CLOCKWISE``
    (``"R90"``), ``cv2.ROTATE_180`` (``"R180"``) and
    ``cv2.ROTATE_90_COUNTERCLOCKWISE`` (``"R270"``).

    Rotating the (small) template lets callers search for rotated
    occurrences without having to rotate the drawing itself.

    Args:
        template: Template image to rotate.

    Returns:
        A list of ``(image, rotation_label)`` tuples.
    """
    rotation_codes = (
        ("R90", cv2.ROTATE_90_CLOCKWISE),
        ("R180", cv2.ROTATE_180),
        ("R270", cv2.ROTATE_90_COUNTERCLOCKWISE),
    )
    # The unrotated entry is a copy so the caller's array is never shared.
    oriented = [(template.copy(), "R0")]
    for label, code in rotation_codes:
        oriented.append((cv2.rotate(template, code), label))
    return oriented
class PatternDetector:
    """Orchestrate the template-matching pipeline and its performance tracking.

    Usage order implied by the preconditions in this class:
    ``load_drawing`` -> ``add_templates`` -> ``detect``. ``clear`` drops every
    piece of held state. Note that any non-cancellation failure inside
    ``load_drawing``, ``add_templates`` or ``detect`` calls ``clear``, so the
    drawing and templates must be loaded again afterwards.
    """

    # Matching modes accepted by ``detect``.
    _SUPPORTED_MODES = ("v1", "v2", "v3")
    # Fixed template-match threshold used to generate candidates in "v2" mode.
    _V2_CANDIDATE_THRESHOLD = 0.35
    # IoU threshold of the coarse Soft-NMS pass that runs before the CNN stage.
    _COARSE_NMS_IOU = 0.5
    # With extractor_type="auto": templates whose shorter side is below this
    # many pixels use "resnet18", all others use "dinov2".
    _AUTO_BACKBONE_MIN_SIDE = 56
    # Search radius handed to refine_bbox_local_search.
    _REFINE_SEARCH_RADIUS = 8

    def __init__(self, device: str = "cpu") -> None:
        """Create an empty detector.

        Args:
            device: Device string forwarded to the shared feature extractor.
        """
        self.device = device
        # Both stay None until load_drawing() succeeds.
        self.drawing_raw: Optional[np.ndarray] = None
        self.drawing_gray: Optional[np.ndarray] = None
        # (template image, rotation label) pairs filled by add_templates().
        self.templates_variants: List[Tuple[np.ndarray, str]] = []
        self.tracker = PerformanceTracker()

    def clear(self) -> None:
        """Drop the drawing, the templates and the tracker, then free CUDA cache.

        The tracker is replaced by a fresh ``PerformanceTracker`` and
        ``torch.cuda.empty_cache()`` is called when CUDA is available.
        """
        self.drawing_raw = None
        self.drawing_gray = None
        self.templates_variants = []
        self.tracker = PerformanceTracker()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_drawing(self, drawing_img: np.ndarray) -> None:
        """Store a copy of the drawing together with its grayscale version.

        A 3-dimensional array is converted with ``cv2.COLOR_BGR2GRAY``; any
        other array is copied as-is and treated as already grayscale.

        Args:
            drawing_img: Drawing image; must be non-None and non-empty.

        Raises:
            InvalidImageException: If the image is empty, or if any
                non-``BOMDetectorException`` error occurs while loading
                (the original error is attached as ``__cause__``).
            BOMDetectorException: Re-raised unchanged when raised inside.

        On any failure the detector state is wiped via ``clear()``.
        """
        try:
            self.tracker.start_stage("load_and_normalize_drawing")
            if drawing_img is None or drawing_img.size == 0:
                raise InvalidImageException("Ảnh bản vẽ đầu vào trống.")
            self.drawing_raw = drawing_img.copy()
            if drawing_img.ndim == 3:
                self.drawing_gray = cv2.cvtColor(drawing_img, cv2.COLOR_BGR2GRAY)
            else:
                self.drawing_gray = drawing_img.copy()
            self.tracker.end_stage("load_and_normalize_drawing")
        except BOMDetectorException:
            self.clear()
            raise
        except Exception as e:
            self.clear()
            raise InvalidImageException(f"Lỗi nạp bản vẽ: {str(e)}") from e

    def add_templates(self, templates: List[np.ndarray], with_rotation: bool = False) -> None:
        """Register templates, replacing any previously registered ones.

        Each template is converted to grayscale (when 3-dimensional), checked
        with ``validate_inputs`` against the drawing, and passed through
        ``synchronize_polarity`` together with the drawing; the second element
        of that call's result is what gets stored.

        Args:
            templates: Non-empty list of template images.
            with_rotation: When True, store the four variants produced by
                ``generate_template_variants`` for every template; otherwise
                store a single ``"R0"`` entry per template.

        Raises:
            InvalidImageException: If a template is None or empty.
            BOMDetectorException: If no drawing is loaded, the list is empty,
                or any other error occurs (wrapped, with the original error
                attached as ``__cause__``).

        On any failure the detector state (including the drawing) is wiped
        via ``clear()``.
        """
        try:
            self.tracker.start_stage("add_templates")
            if not templates:
                raise BOMDetectorException("Không có template.")
            # Explicit guard, consistent with detect(): the steps below all
            # need the grayscale drawing.
            if self.drawing_gray is None:
                raise BOMDetectorException("Drawing chưa được nạp.")
            self.templates_variants = []
            for tmpl in templates:
                if tmpl is None or tmpl.size == 0:
                    raise InvalidImageException("Ảnh mẫu trống.")
                tmpl_gray = cv2.cvtColor(tmpl, cv2.COLOR_BGR2GRAY) if tmpl.ndim == 3 else tmpl.copy()
                validate_inputs(self.drawing_gray, tmpl_gray)
                _, tmpl_sync = synchronize_polarity(self.drawing_gray, tmpl_gray)
                if with_rotation:
                    self.templates_variants.extend(generate_template_variants(tmpl_sync))
                else:
                    self.templates_variants.append((tmpl_sync, "R0"))
            self.tracker.end_stage("add_templates")
        except BOMDetectorException:
            self.clear()
            raise
        except Exception as e:
            self.clear()
            raise BOMDetectorException(f"Lỗi đăng ký template: {str(e)}") from e

    @staticmethod
    def _check_cancelled(cancellation_state: Optional[CancellationState]) -> None:
        """Call ``cancellation_state.check()`` when a state object was given."""
        if cancellation_state is not None:
            cancellation_state.check()

    def _match_candidates(
        self,
        stage: str,
        drawing_edge: np.ndarray,
        tmpl_edge: np.ndarray,
        threshold: float,
        cancellation_state: Optional[CancellationState],
    ) -> List[Any]:
        """Run ``multiscale_template_match`` inside the tracker stage ``stage``."""
        self.tracker.start_stage(stage)
        proposals = multiscale_template_match(
            drawing_edge, tmpl_edge, threshold=threshold, cancellation_state=cancellation_state
        )
        self.tracker.end_stage(stage)
        return proposals

    def _pick_extractor(self, extractor_type: str, tmpl: np.ndarray) -> Any:
        """Return the shared feature extractor, resolving ``"auto"`` by template size."""
        backbone = extractor_type
        if extractor_type == "auto":
            th, tw = tmpl.shape[:2]
            backbone = "resnet18" if min(th, tw) < self._AUTO_BACKBONE_MIN_SIDE else "dinov2"
        return get_shared_feature_extractor(backbone=backbone, device=self.device)

    @staticmethod
    def _crop_with_margin(
        drawing: np.ndarray,
        proposals: List[Any],
        margin_pct: float,
    ) -> List[np.ndarray]:
        """Cut one crop per proposal, enlarged by ``margin_pct`` of its size.

        Proposals are indexed as ``(x, y, width, height, ...)``. The enlarged
        box is clamped to the drawing bounds.
        """
        height, width = drawing.shape[:2]
        crops = []
        for p in proposals:
            x, y, bw, bh = p[0], p[1], p[2], p[3]
            margin_y = int(bh * margin_pct)
            margin_x = int(bw * margin_pct)
            x1 = max(0, x - margin_x)
            y1 = max(0, y - margin_y)
            x2 = min(width, x + bw + margin_x)
            y2 = min(height, y + bh + margin_y)
            crops.append(drawing[y1:y2, x1:x2])
        return crops

    def _score_candidates(
        self,
        prefix: str,
        rotation_name: str,
        proposals: List[Any],
        drawing_sync: np.ndarray,
        tmpl_sync: np.ndarray,
        *,
        prune_threshold: float,
        variance_std_threshold: float,
        context_margin_pct: float,
        extractor_type: str,
        cancellation_state: Optional[CancellationState],
    ) -> List[Tuple[Any, float]]:
        """Prune, filter and CNN-score the candidates of one template variant.

        Shared by the "v2" and "v3" modes. Steps: coarse gaussian Soft-NMS,
        ``filter_informative_proposals``, feature extraction of the template
        and of every padded crop, then cosine similarity between them.
        Tracker stages are named ``{prefix}_Blank_Filtering_*``,
        ``{prefix}_CNN_Init_*`` and ``{prefix}_Batch_CNN_*``.

        Returns:
            ``(proposal, cosine_score)`` pairs for the surviving proposals;
            an empty list when nothing is left before the CNN stage. Each
            proposal is a tuple ``(x, y, w, h, confidence, scale)`` whose
            confidence is the value returned by the coarse Soft-NMS.
        """
        if not proposals:
            return []

        # Coarse NMS to prune candidates before the heavy CNN stage.
        pruned = soft_nms(
            [
                {"bbox": p[:4], "confidence": float(p[4]), "scale": float(p[5])}
                for p in proposals
            ],
            iou_threshold=self._COARSE_NMS_IOU,
            score_threshold=prune_threshold,
            method="gaussian",
        )
        proposals = []
        for entry in pruned:
            box = entry["bbox"]
            proposals.append(
                (box[0], box[1], box[2], box[3], entry["confidence"], entry["scale"])
            )

        filter_stage = f"{prefix}_Blank_Filtering_{rotation_name}"
        self.tracker.start_stage(filter_stage)
        proposals = filter_informative_proposals(
            proposals, drawing_sync, std_threshold=variance_std_threshold
        )
        self.tracker.end_stage(filter_stage)
        if not proposals:
            return []

        self._check_cancelled(cancellation_state)

        init_stage = f"{prefix}_CNN_Init_{rotation_name}"
        self.tracker.start_stage(init_stage)
        extractor = self._pick_extractor(extractor_type, tmpl_sync)
        self.tracker.end_stage(init_stage)

        batch_stage = f"{prefix}_Batch_CNN_{rotation_name}"
        self.tracker.start_stage(batch_stage)
        crops = self._crop_with_margin(drawing_sync, proposals, context_margin_pct)
        template_vec = extractor.extract(tmpl_sync)
        crop_vecs = extractor.extract_batch(crops)
        # Repeat the template vector so it lines up row-by-row with the crops.
        template_vecs = template_vec.unsqueeze(0).expand(len(crops), -1)
        similarities = F.cosine_similarity(crop_vecs, template_vecs, dim=1)
        self.tracker.end_stage(batch_stage)

        # One bulk conversion instead of a per-element .item() call.
        score_list = similarities.tolist()
        return [(p, float(score_list[i])) for i, p in enumerate(proposals)]

    def detect(
        self,
        mode: str = "v3",
        confidence_threshold: float = 0.75,
        v1_threshold: float = 0.50,
        v2_threshold: float = 0.80,
        alpha: float = 0.30,
        iou_threshold: float = 0.30,
        enable_local_refine: bool = False,
        variance_std_threshold: float = 5.0,
        context_margin_pct: float = 0.15,
        extractor_type: str = "auto",
        cancellation_state: Optional[CancellationState] = None,
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Run detection of every registered template variant on the drawing.

        Modes:
            * ``"v1"``: edge-based multiscale template matching only; the
              match score is the confidence.
            * ``"v2"``: template matching (fixed threshold 0.35) produces
              candidates, which are re-scored by cosine similarity of
              extracted features; that similarity is the confidence.
            * ``"v3"``: like "v2" but candidates use ``v1_threshold`` and the
              confidence is ``alpha * score_v1 + (1 - alpha) * score_v2``.

        Args:
            mode: One of ``"v1"``, ``"v2"``, ``"v3"``.
            confidence_threshold: Score threshold of the final Soft-NMS.
            v1_threshold: Template-match threshold ("v1" and "v3").
            v2_threshold: Minimum cosine similarity to keep a candidate
                ("v2" and "v3").
            alpha: Weight of the template-match score in "v3" fusion.
            iou_threshold: IoU threshold of the final Soft-NMS.
            enable_local_refine: When True, replace each final bbox with the
                one returned by ``refine_bbox_local_search``.
            variance_std_threshold: ``std_threshold`` forwarded to
                ``filter_informative_proposals``.
            context_margin_pct: Fraction of the box size added around each
                candidate before cropping it for feature extraction.
            extractor_type: Backbone name, or ``"auto"`` to choose by
                template size.
            cancellation_state: Optional object whose ``check()`` is called
                at several points of the run.

        Returns:
            ``(results, report)``. Each result dict has the keys ``bbox``,
            ``confidence``, ``rotation``, ``scale`` and ``variant_idx``
            ("v3" adds ``score_v1`` and ``score_v2``). ``report`` is
            ``tracker.get_report()`` extended with ``num_proposals_total``
            and ``num_detected``.

        Raises:
            BOMDetectorException: If no drawing or template is registered,
                ``mode`` is not supported, or any other error occurs during
                the run (wrapped, original attached as ``__cause__``).
            DetectionCancelledException: Propagated unchanged when raised
                during the run; only the CUDA cache is released in that case.

        Any other failure inside the run wipes the detector via ``clear()``.
        """
        self._check_cancelled(cancellation_state)
        if self.drawing_gray is None:
            raise BOMDetectorException("Drawing chưa được nạp.")
        if not self.templates_variants:
            raise BOMDetectorException("Template chưa được đăng ký.")
        # Reject unknown modes explicitly; otherwise no branch below would
        # run and the caller would silently get an empty result.
        if mode not in self._SUPPORTED_MODES:
            raise BOMDetectorException(
                f"Mode không hợp lệ: {mode!r} (hỗ trợ: v1, v2, v3)."
            )

        all_results: List[Dict[str, Any]] = []
        try:
            self._check_cancelled(cancellation_state)

            # 1. Drawing polarity synchronization, done once for all variants.
            self.tracker.start_stage("Prep_Drawing_Polarity")
            drawing_sync, _ = synchronize_polarity(self.drawing_gray, self.drawing_gray)
            self.tracker.end_stage("Prep_Drawing_Polarity")

            # 2. Drawing edge preprocessing, done once for all variants.
            self.tracker.start_stage("Prep_Drawing_Edges")
            drawing_edge = preprocess_for_matching(drawing_sync, method="dilated_edge")
            self.tracker.end_stage("Prep_Drawing_Edges")

            # 3. Template edges, index-aligned with self.templates_variants.
            self.tracker.start_stage("Prep_Template_Edges")
            tmpl_edges = [
                preprocess_for_matching(tmpl, method="dilated_edge")
                for tmpl, _ in self.templates_variants
            ]
            self.tracker.end_stage("Prep_Template_Edges")

            self._check_cancelled(cancellation_state)

            # Templates were already polarity-synchronized in add_templates().
            for idx, (tmpl_sync, rotation_name) in enumerate(self.templates_variants):
                self._check_cancelled(cancellation_state)
                tmpl_edge = tmpl_edges[idx]

                if mode == "v1":
                    # Edges are precomputed above; the empty stage is kept so
                    # the report still contains the V1_Prep_* entry.
                    self.tracker.start_stage(f"V1_Prep_{rotation_name}")
                    self.tracker.end_stage(f"V1_Prep_{rotation_name}")

                    proposals = self._match_candidates(
                        f"V1_Matching_{rotation_name}",
                        drawing_edge, tmpl_edge, v1_threshold, cancellation_state,
                    )
                    for p in proposals:
                        all_results.append({
                            "bbox": p[:4],
                            "confidence": float(p[4]),
                            "rotation": rotation_name,
                            "scale": float(p[5]),
                            "variant_idx": idx
                        })

                elif mode == "v2":
                    proposals = self._match_candidates(
                        f"V2_Candidate_Gen_{rotation_name}",
                        drawing_edge, tmpl_edge,
                        self._V2_CANDIDATE_THRESHOLD, cancellation_state,
                    )
                    scored = self._score_candidates(
                        "V2", rotation_name, proposals, drawing_sync, tmpl_sync,
                        prune_threshold=self._V2_CANDIDATE_THRESHOLD,
                        variance_std_threshold=variance_std_threshold,
                        context_margin_pct=context_margin_pct,
                        extractor_type=extractor_type,
                        cancellation_state=cancellation_state,
                    )
                    for p, s_v2 in scored:
                        if s_v2 >= v2_threshold:
                            all_results.append({
                                "bbox": p[:4],
                                "confidence": s_v2,
                                "rotation": rotation_name,
                                "scale": float(p[5]),
                                "variant_idx": idx
                            })

                else:  # mode == "v3" (validated above)
                    proposals = self._match_candidates(
                        f"V3_Coarse_V1_{rotation_name}",
                        drawing_edge, tmpl_edge, v1_threshold, cancellation_state,
                    )
                    scored = self._score_candidates(
                        "V3", rotation_name, proposals, drawing_sync, tmpl_sync,
                        prune_threshold=v1_threshold,
                        variance_std_threshold=variance_std_threshold,
                        context_margin_pct=context_margin_pct,
                        extractor_type=extractor_type,
                        cancellation_state=cancellation_state,
                    )
                    if not scored:
                        # Nothing reached the CNN stage: skip the fusion stage.
                        continue

                    self.tracker.start_stage(f"V3_Score_Fusion_{rotation_name}")
                    for p, s_v2 in scored:
                        s_v1 = float(p[4])
                        if s_v2 >= v2_threshold:
                            score_final = alpha * s_v1 + (1 - alpha) * s_v2
                            all_results.append({
                                "bbox": p[:4],
                                "confidence": score_final,
                                "score_v1": s_v1,
                                "score_v2": s_v2,
                                "rotation": rotation_name,
                                "scale": float(p[5]),
                                "variant_idx": idx
                            })
                    self.tracker.end_stage(f"V3_Score_Fusion_{rotation_name}")

            self._check_cancelled(cancellation_state)

            # Merge overlapping detections with Soft-NMS.
            self.tracker.start_stage("Postprocessing_Soft_NMS")
            nms_results = soft_nms(
                all_results,
                iou_threshold=iou_threshold,
                score_threshold=confidence_threshold,
                method="gaussian"
            )
            self.tracker.end_stage("Postprocessing_Soft_NMS")

            # Optional local bbox refinement against the matching variant's edges.
            if enable_local_refine and nms_results:
                self.tracker.start_stage("BBox_Local_Refinement")
                for res in nms_results:
                    self._check_cancelled(cancellation_state)
                    x, y, w, h = res["bbox"]
                    variant_edge = tmpl_edges[res["variant_idx"]]
                    # The fifth returned value (a score) is not used; only the
                    # box is updated, in place.
                    rx, ry, rw, rh, _ = refine_bbox_local_search(
                        drawing_edge, (x, y, w, h), variant_edge,
                        search_radius=self._REFINE_SEARCH_RADIUS,
                    )
                    res["bbox"] = (rx, ry, rw, rh)
                self.tracker.end_stage("BBox_Local_Refinement")

            # Build the performance report.
            report = self.tracker.get_report()
            report["num_proposals_total"] = len(all_results)
            report["num_detected"] = len(nms_results)
            return nms_results, report

        except DetectionCancelledException:
            # Cancellation keeps the drawing and templates; only free VRAM.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise
        except BOMDetectorException:
            self.clear()
            raise
        except Exception as e:
            self.clear()
            raise BOMDetectorException(f"Lỗi trong quá trình detect: {str(e)}") from e