from typing import Protocol, List

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms

from src.exceptions import ModelLoadException


# Shared preprocessing pipeline: HWC ndarray -> CHW tensor, followed by
# per-channel normalization with the ImageNet mean/std constants below.
# Built once at import time instead of on every extract() call; neither
# transform keeps per-call state, so one instance can be reused.
_IMAGENET_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def _gray_to_model_input(img_gray: np.ndarray, size: tuple[int, int]) -> torch.Tensor:
    """Resize a grayscale image to *size* and return it as a normalized 3-channel tensor."""
    resized = cv2.resize(img_gray, size, interpolation=cv2.INTER_AREA)
    # Replicate the single channel three times along the last axis (H, W, 3),
    # because both backbones are fed 3-channel input.
    rgb = np.stack([resized] * 3, axis=2)
    return _IMAGENET_TRANSFORM(rgb)


def _build_batch(imgs: list[np.ndarray], size: tuple[int, int],
                 device: torch.device) -> torch.Tensor:
    """Preprocess every image in *imgs* and stack the results into one batch on *device*."""
    return torch.stack([_gray_to_model_input(img, size) for img in imgs]).to(device)


class FeatureExtractor(Protocol):
    """Structural interface shared by the feature extractors in this module."""

    def extract(self, img_gray: np.ndarray) -> torch.Tensor: ...

    def extract_batch(self, imgs: list[np.ndarray]) -> torch.Tensor: ...


class DeepFeatureExtractor:
    """
    Feature extractor built from the early layers of ResNet18, intended to
    retain corner/edge spatial information.

    Images are resized to ``TARGET_SIZE`` before being passed to the network.
    """

    TARGET_SIZE = (128, 128)

    def __init__(self, device: str = "cpu") -> None:
        """Load pretrained ResNet18 and keep its first six child modules.

        Args:
            device: Device string handed to ``torch.device``.

        Raises:
            ModelLoadException: If the model cannot be created, truncated or
                moved to the device. The original error is chained as the cause.
        """
        # Kept outside the try block on purpose: an invalid device string
        # surfaces as the raw torch error, exactly as before.
        self.device = torch.device(device)
        try:
            # Passing ``weights=`` avoids the deprecation warning of the
            # older ``pretrained=`` argument.
            model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
            self.extractor = torch.nn.Sequential(*list(model.children())[:6])
            self.extractor.to(self.device)
            self.extractor.eval()
            # Inference only: freeze every parameter so no gradients are tracked.
            for p in self.extractor.parameters():
                p.requires_grad_(False)
        except Exception as e:
            raise ModelLoadException(f"Không thể khởi tạo mô hình ResNet18: {e}") from e

    def extract(self, img_gray: np.ndarray) -> torch.Tensor:
        """Turn a single grayscale image into a flat, normalized feature vector.

        Args:
            img_gray: Grayscale image as a NumPy array.

        Returns:
            A 1-D tensor on ``self.device``: the flattened feature map after
            ``F.normalize``.
        """
        tensor = _gray_to_model_input(img_gray, self.TARGET_SIZE).unsqueeze(0).to(self.device)
        with torch.no_grad():
            feat = self.extractor(tensor)
        # Flatten the whole feature map; every spatial position keeps its own
        # slot in the resulting vector.
        return F.normalize(feat.flatten(), dim=0)

    def extract_batch(self, imgs: list[np.ndarray]) -> torch.Tensor:
        """Extract features for several grayscale images in one forward pass.

        Args:
            imgs: Grayscale images as NumPy arrays.

        Returns:
            A 2-D tensor with one normalized feature row per image. For an
            empty *imgs* the result is an empty 1-D tensor (``torch.empty(0)``)
            on ``self.device``.
        """
        if not imgs:
            return torch.empty(0, device=self.device)
        batch = _build_batch(imgs, self.TARGET_SIZE, self.device)
        with torch.no_grad():
            feats = self.extractor(batch)
        # Flatten per sample (keep the batch axis) and normalize each row.
        return F.normalize(feats.flatten(start_dim=1), dim=1)


class DINOv2Extractor:
    """
    Feature extractor backed by the DINOv2 foundation model
    (``dinov2_vits14`` loaded through PyTorch Hub).

    Images are resized to ``TARGET_SIZE`` before being passed to the network.
    """

    TARGET_SIZE = (224, 224)

    def __init__(self, device: str = "cpu") -> None:
        """Load ``dinov2_vits14`` from PyTorch Hub and prepare it for inference.

        Args:
            device: Device string handed to ``torch.device``.

        Raises:
            ModelLoadException: If the model cannot be loaded or moved to the
                device. The original error is chained as the cause.
        """
        self.device = torch.device(device)
        try:
            # Load from PyTorch Hub.
            self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
            self.model.to(self.device)
            self.model.eval()
            # Inference only: freeze every parameter so no gradients are tracked.
            for p in self.model.parameters():
                p.requires_grad_(False)
        except Exception as e:
            raise ModelLoadException(f"Không thể khởi tạo mô hình DINOv2: {e}") from e

    def extract(self, img_gray: np.ndarray) -> torch.Tensor:
        """Turn a single grayscale image into a flat, normalized feature vector.

        Args:
            img_gray: Grayscale image as a NumPy array.

        Returns:
            A 1-D tensor on ``self.device``: the flattened model output after
            ``F.normalize``.
        """
        tensor = _gray_to_model_input(img_gray, self.TARGET_SIZE).unsqueeze(0).to(self.device)
        with torch.no_grad():
            feat = self.model(tensor)
        # If the model hands back a dict, use its "x_norm_clstoken" entry.
        if isinstance(feat, dict):
            feat = feat["x_norm_clstoken"]
        return F.normalize(feat.flatten(), dim=0)

    def extract_batch(self, imgs: list[np.ndarray]) -> torch.Tensor:
        """Extract features for several grayscale images in one forward pass.

        Args:
            imgs: Grayscale images as NumPy arrays.

        Returns:
            The model output normalized along dim 1, one row per image. For an
            empty *imgs* the result is an empty 1-D tensor (``torch.empty(0)``)
            on ``self.device``.
        """
        if not imgs:
            return torch.empty(0, device=self.device)
        batch = _build_batch(imgs, self.TARGET_SIZE, self.device)
        with torch.no_grad():
            feats = self.model(batch)
        # If the model hands back a dict, use its "x_norm_clstoken" entry.
        if isinstance(feats, dict):
            feats = feats["x_norm_clstoken"]
        return F.normalize(feats, dim=1)
# Singleton cache: one extractor instance per (backbone, device) pair.
_EXTRACTOR_CACHE = {}


def get_shared_feature_extractor(backbone: str = "resnet18", device: str = "cpu") -> "FeatureExtractor":
    """
    Return a cached extractor, falling back across devices and models on failure.

    Fallback order:
      1. A CUDA device is requested but CUDA is unavailable -> use "cpu".
      2. "dinov2" fails to load -> ResNet18 on the same device.
      3. Any other backbone fails on a non-CPU device -> ResNet18 on CPU.

    An extractor obtained through fallback 2 or 3 is cached under the key it
    was actually built for, not under the originally requested key, so the
    failing combination is attempted again on the next call.

    Args:
        backbone: "dinov2" selects ``DINOv2Extractor``; any other value
            selects ``DeepFeatureExtractor`` (ResNet18).
        device: Requested device string, e.g. "cpu", "cuda" or "cuda:0".

    Returns:
        The shared extractor instance.

    Raises:
        ModelLoadException: If a non-DINOv2 backbone cannot be loaded on CPU.
    """
    actual_device = device
    # Prefix match so that indexed devices ("cuda:0", "cuda:1", ...) take the
    # device fallback too. With an exact == "cuda" test they slipped through,
    # failed inside the constructor, and a "dinov2" request was downgraded to
    # ResNet18 instead of simply running DINOv2 on CPU.
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        print("[Warning] Khởi chạy trên CUDA bất khả thi. Tự động Fallback sang CPU.")
        actual_device = "cpu"

    cache_key = (backbone, actual_device)
    if cache_key in _EXTRACTOR_CACHE:
        return _EXTRACTOR_CACHE[cache_key]

    try:
        if backbone == "dinov2":
            extractor = DINOv2Extractor(device=actual_device)
        else:
            extractor = DeepFeatureExtractor(device=actual_device)
    except ModelLoadException as e:
        print(f"[Warning] Khởi chạy model {backbone} thất bại: {e}.")
        if backbone == "dinov2":
            print("Tự động Fallback hạ cấp xuống ResNet18.")
            return get_shared_feature_extractor(backbone="resnet18", device=actual_device)
        if actual_device == "cpu":
            # Nothing left to fall back to; bare raise keeps the original traceback.
            raise
        print("Tự động Fallback hạ cấp xuống ResNet18 trên CPU.")
        return get_shared_feature_extractor(backbone="resnet18", device="cpu")
    except Exception as e:
        print(f"[Warning] Lỗi không mong đợi: {e}. Fallback sang ResNet18 CPU.")
        if backbone == "resnet18" and actual_device == "cpu":
            # Already at the last-resort configuration: stop recursing.
            raise ModelLoadException(f"Không thể tải mô hình dự phòng ResNet18 trên CPU: {e}") from e
        return get_shared_feature_extractor(backbone="resnet18", device="cpu")

    _EXTRACTOR_CACHE[cache_key] = extractor
    return extractor


def choose_extractor(template: np.ndarray, resnet_ext: "FeatureExtractor",
                     dino_ext: "FeatureExtractor", *, min_side: int = 56) -> "FeatureExtractor":
    """
    Pick the extractor suited to the template size.

    Templates whose shorter side is below *min_side* pixels use the ResNet18
    extractor, to avoid DINOv2's sparse tokens on small inputs.

    Args:
        template: Image array; only the first two entries of ``shape`` are read.
        resnet_ext: Extractor returned for small templates.
        dino_ext: Extractor returned otherwise.
        min_side: Threshold on the shorter side, in pixels (default 56).

    Returns:
        ``resnet_ext`` if ``min(height, width) < min_side``, else ``dino_ext``.
    """
    height, width = template.shape[:2]
    return resnet_ext if min(height, width) < min_side else dino_ext