File size: 7,367 Bytes
8da7bdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
import cv2
import numpy as np
from typing import Protocol, List
from src.exceptions import ModelLoadException

class FeatureExtractor(Protocol):
    """Structural (duck-typed) interface implemented by the extractors in this module."""

    def extract(self, img_gray: np.ndarray) -> torch.Tensor:
        """Return the feature vector for a single grayscale image."""

    def extract_batch(self, imgs: List[np.ndarray]) -> torch.Tensor:
        """Return features for a list of images.

        The implementations in this module return one L2-normalized row per image.
        """

class DeepFeatureExtractor:
    """
    Extract features from the early layers of ResNet18 to preserve spatial
    corner/edge information.

    Only the first six children of torchvision's ResNet18 are kept (for that
    model: conv1, bn1, relu, maxpool, layer1, layer2), so the network outputs a
    spatial feature map rather than a pooled vector. Each map is flattened and
    L2-normalized per image.
    """
    TARGET_SIZE = (128, 128)
    # ImageNet normalization statistics matching the pretrained weights.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, device: str = "cpu") -> None:
        """Load the truncated, frozen ResNet18 backbone onto ``device``.

        Raises:
            ModelLoadException: if the model cannot be built or moved to the device.
        """
        self.device = torch.device(device)
        # Build the preprocessing pipeline once; it is reused by every call.
        self._transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=list(self._MEAN), std=list(self._STD)),
        ])
        try:
            # Use the `weights=` API to avoid the deprecated `pretrained=` flag.
            model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
            self.extractor = torch.nn.Sequential(*list(model.children())[:6])
            self.extractor.to(self.device)
            self.extractor.eval()

            # Freeze parameters: the backbone is only used for inference.
            for p in self.extractor.parameters():
                p.requires_grad_(False)
        except Exception as e:
            raise ModelLoadException(f"Không thể khởi tạo mô hình ResNet18: {e}") from e

    @staticmethod
    def _gray_to_rgb(img: np.ndarray) -> np.ndarray:
        """Replicate a single-channel image into an (H, W, 3) array.

        Raises:
            ValueError: if ``img`` is not shaped (H, W) or (H, W, 1).
        """
        if img.ndim == 3 and img.shape[2] == 1:
            img = img[:, :, 0]
        if img.ndim != 2:
            raise ValueError(
                f"Expected a single-channel (H, W) image, got shape {img.shape}"
            )
        return np.stack([img] * 3, axis=2)

    def _preprocess(self, img: np.ndarray) -> torch.Tensor:
        """Resize one image, replicate it to 3 channels and normalize it to a (3, H, W) tensor."""
        resized = cv2.resize(img, self.TARGET_SIZE, interpolation=cv2.INTER_AREA)
        # NOTE(review): ToTensor only rescales uint8 input to [0, 1]; float input
        # is assumed to be pre-scaled already — confirm with callers.
        return self._transform(self._gray_to_rgb(resized))

    def extract(self, img_gray: np.ndarray) -> torch.Tensor:
        """Return the flattened, L2-normalized feature vector of one grayscale image."""
        return self.extract_batch([img_gray])[0]

    def extract_batch(self, imgs: list[np.ndarray]) -> torch.Tensor:
        """Extract features for several images in a single forward pass.

        Returns:
            A (N, D) tensor with one L2-normalized row per input image, or an
            empty 1-D tensor when ``imgs`` is empty.
        """
        if not imgs:
            # Shape (0,) rather than (0, D) is kept for backward compatibility.
            return torch.empty(0, device=self.device)

        batch = torch.stack([self._preprocess(img) for img in imgs]).to(self.device)
        with torch.no_grad():
            feats = self.extractor(batch)

        # Flatten each per-image feature map; element order keeps the spatial layout.
        return F.normalize(feats.flatten(start_dim=1), dim=1)


class DINOv2Extractor:
    """
    Feature extractor backed by the DINOv2 foundation model (``dinov2_vits14``
    from PyTorch Hub).

    Images are resized to 224x224, which is a 16x16 grid of 14-px patches.
    Each image yields one L2-normalized embedding.
    """
    TARGET_SIZE = (224, 224)
    # ImageNet normalization statistics.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, device: str = "cpu") -> None:
        """Load the frozen DINOv2 model onto ``device``.

        Raises:
            ModelLoadException: if the hub model cannot be loaded or moved.
        """
        self.device = torch.device(device)
        # Build the preprocessing pipeline once; it is reused by every call.
        self._transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=list(self._MEAN), std=list(self._STD)),
        ])
        try:
            # Loaded through PyTorch Hub; this may need network access.
            self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
            self.model.to(self.device)
            self.model.eval()
            # Freeze parameters: the model is only used for inference.
            for p in self.model.parameters():
                p.requires_grad_(False)
        except Exception as e:
            raise ModelLoadException(f"Không thể khởi tạo mô hình DINOv2: {e}") from e

    @staticmethod
    def _gray_to_rgb(img: np.ndarray) -> np.ndarray:
        """Replicate a single-channel image into an (H, W, 3) array.

        Raises:
            ValueError: if ``img`` is not shaped (H, W) or (H, W, 1).
        """
        if img.ndim == 3 and img.shape[2] == 1:
            img = img[:, :, 0]
        if img.ndim != 2:
            raise ValueError(
                f"Expected a single-channel (H, W) image, got shape {img.shape}"
            )
        return np.stack([img] * 3, axis=2)

    def _preprocess(self, img: np.ndarray) -> torch.Tensor:
        """Resize one image, replicate it to 3 channels and normalize it to a (3, H, W) tensor."""
        resized = cv2.resize(img, self.TARGET_SIZE, interpolation=cv2.INTER_AREA)
        # NOTE(review): ToTensor only rescales uint8 input to [0, 1]; float input
        # is assumed to be pre-scaled already — confirm with callers.
        return self._transform(self._gray_to_rgb(resized))

    def extract(self, img_gray: np.ndarray) -> torch.Tensor:
        """Return the L2-normalized embedding of one grayscale image."""
        return self.extract_batch([img_gray])[0]

    def extract_batch(self, imgs: list[np.ndarray]) -> torch.Tensor:
        """Embed several images in a single forward pass.

        Returns:
            A (N, D) tensor with one L2-normalized row per input image, or an
            empty 1-D tensor when ``imgs`` is empty.
        """
        if not imgs:
            # Shape (0,) rather than (0, D) is kept for backward compatibility.
            return torch.empty(0, device=self.device)

        batch = torch.stack([self._preprocess(img) for img in imgs]).to(self.device)
        with torch.no_grad():
            feats = self.model(batch)
            # Some model entry points return a dict; use the normalized CLS token then.
            if isinstance(feats, dict):
                feats = feats["x_norm_clstoken"]

        # Flatten to (N, D) so the per-row normalization matches DeepFeatureExtractor.
        return F.normalize(feats.flatten(start_dim=1), dim=1)


# Process-wide singleton cache of built extractors, keyed by
# (backbone name, resolved device string).
_EXTRACTOR_CACHE: "dict[tuple[str, str], FeatureExtractor]" = dict()

def get_shared_feature_extractor(backbone: str = "resnet18", device: str = "cpu") -> FeatureExtractor:
    """
    Return a cached (singleton) feature extractor for ``backbone`` on ``device``.

    Fallback chain:
      * A CUDA device ("cuda" or an indexed form such as "cuda:0") is replaced
        by "cpu" when CUDA is unavailable.
      * "dinov2" that fails to load falls back to ResNet18 on the same device.
      * ResNet18 that fails on a non-CPU device falls back to ResNet18 on CPU.
      * Any backbone name other than "dinov2" builds ResNet18.

    The extractor actually returned, including a fallback, is cached under the
    requested ``(backbone, device)`` key. A failing load (e.g. a torch.hub
    fetch) is therefore attempted once per process, not on every call.

    Raises:
        ModelLoadException: if ResNet18 cannot be loaded even on CPU.
    """
    actual_device = device
    # Match the device type so indexed CUDA devices also take the CPU fallback.
    if str(device).split(":", 1)[0] == "cuda" and not torch.cuda.is_available():
        print("[Warning] Khởi chạy trên CUDA bất khả thi. Tự động Fallback sang CPU.")
        actual_device = "cpu"

    cache_key = (backbone, actual_device)
    if cache_key in _EXTRACTOR_CACHE:
        return _EXTRACTOR_CACHE[cache_key]

    try:
        if backbone == "dinov2":
            extractor = DINOv2Extractor(device=actual_device)
        else:
            extractor = DeepFeatureExtractor(device=actual_device)
    except ModelLoadException as e:
        print(f"[Warning] Khởi chạy model {backbone} thất bại: {e}.")
        if backbone == "dinov2":
            print("Tự động Fallback hạ cấp xuống ResNet18.")
            extractor = get_shared_feature_extractor(backbone="resnet18", device=actual_device)
        else:
            if actual_device == "cpu":
                # Last resort already failed: nothing left to fall back to.
                raise
            print("Tự động Fallback hạ cấp xuống ResNet18 trên CPU.")
            extractor = get_shared_feature_extractor(backbone="resnet18", device="cpu")
    except Exception as e:
        print(f"[Warning] Lỗi không mong đợi: {e}. Fallback sang ResNet18 CPU.")
        if backbone == "resnet18" and actual_device == "cpu":
            raise ModelLoadException(f"Không thể tải mô hình dự phòng ResNet18 trên CPU: {e}") from e
        extractor = get_shared_feature_extractor(backbone="resnet18", device="cpu")

    # Cache under the requested key too, so a failed backbone is not retried per call.
    _EXTRACTOR_CACHE[cache_key] = extractor
    return extractor

def choose_extractor(
    template: np.ndarray,
    resnet_ext: "FeatureExtractor",
    dino_ext: "FeatureExtractor",
    min_side: int = 56,
) -> "FeatureExtractor":
    """
    Pick the extractor suited to a template's size.

    Templates whose shorter side is below ``min_side`` pixels use ResNet18,
    avoiding DINOv2's sparse patch tokens on tiny inputs. Larger templates use
    DINOv2. The default of 56 px equals four 14-px ViT patches.

    Args:
        template: Image array whose first two dimensions are (height, width).
        resnet_ext: Extractor returned for small templates.
        dino_ext: Extractor returned for templates of at least ``min_side``.
        min_side: Threshold in pixels on ``min(height, width)``.

    Raises:
        ValueError: if ``template`` has fewer than two dimensions.
    """
    if template.ndim < 2:
        raise ValueError(f"template must be at least 2-D, got shape {template.shape}")
    h, w = template.shape[:2]
    return resnet_ext if min(h, w) < min_side else dino_ext