Spaces:
Sleeping
Sleeping
File size: 7,367 Bytes
8da7bdd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 | import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
import cv2
import numpy as np
from typing import Protocol, List
from src.exceptions import ModelLoadException
class FeatureExtractor(Protocol):
    """Structural interface shared by the feature extractors in this module."""

    def extract(self, img_gray: np.ndarray) -> torch.Tensor:
        """Return the feature tensor for a single image."""

    def extract_batch(self, imgs: List[np.ndarray]) -> torch.Tensor:
        """Return the feature tensor for a list of images."""
class DeepFeatureExtractor:
    """
    Extract features from the early layers of a pretrained ResNet18.

    Only the first six child modules of torchvision's ``resnet18`` are kept.
    Their output is flattened (not pooled) and L2-normalised, which is meant
    to keep spatial corner/edge information in the feature vector.
    """

    TARGET_SIZE = (128, 128)

    def __init__(self, device: str = "cpu") -> None:
        """
        Load the truncated ResNet18 onto ``device`` in eval mode.

        Args:
            device: Device string accepted by ``torch.device``.

        Raises:
            ModelLoadException: If building or moving the model fails.
        """
        self.device = torch.device(device)
        # Built once per instance instead of on every extract call.
        # NOTE(review): ToTensor only rescales uint8 arrays to [0, 1];
        # this presumably expects uint8 images - confirm against callers.
        self._transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        try:
            # Use the weights API to avoid the deprecation warning.
            model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
            self.extractor = torch.nn.Sequential(*list(model.children())[:6])
            self.extractor.to(self.device)
            self.extractor.eval()
            # Freeze the parameters; this extractor is inference-only.
            for p in self.extractor.parameters():
                p.requires_grad_(False)
        except Exception as e:
            raise ModelLoadException(f"Không thể khởi tạo mô hình ResNet18: {e}") from e

    def _preprocess(self, img_gray: np.ndarray) -> torch.Tensor:
        """Resize one grayscale image, replicate it to 3 channels and normalise it."""
        resized = cv2.resize(img_gray, self.TARGET_SIZE, interpolation=cv2.INTER_AREA)
        rgb = np.stack([resized] * 3, axis=2)
        return self._transform(rgb)

    def extract(self, img_gray: np.ndarray) -> torch.Tensor:
        """
        Extract a flat, L2-normalised feature vector from one grayscale image.

        Args:
            img_gray: Single-channel image array.

        Returns:
            1-D tensor on ``self.device`` with unit L2 norm.
        """
        tensor = self._preprocess(img_gray).unsqueeze(0).to(self.device)
        with torch.no_grad():
            feat = self.extractor(tensor)
        # Flatten rather than pool so the spatial layout stays in the vector.
        return F.normalize(feat.flatten(), dim=0)

    def extract_batch(self, imgs: list[np.ndarray]) -> torch.Tensor:
        """
        Extract features for several grayscale images in one forward pass.

        Args:
            imgs: Single-channel image arrays; each is resized to TARGET_SIZE.

        Returns:
            2-D tensor with one L2-normalised row per image. For an empty
            ``imgs`` an empty 1-D tensor (shape ``(0,)``) is returned.
        """
        if not imgs:
            return torch.empty(0, device=self.device)
        batch = torch.stack([self._preprocess(img) for img in imgs]).to(self.device)
        with torch.no_grad():
            feats = self.extractor(batch)
        return F.normalize(feats.flatten(start_dim=1), dim=1)
class DINOv2Extractor:
    """
    Feature extractor backed by the DINOv2 foundation model.

    Loads ``dinov2_vits14`` through PyTorch Hub and returns L2-normalised
    embeddings.
    """

    TARGET_SIZE = (224, 224)

    def __init__(self, device: str = "cpu") -> None:
        """
        Load DINOv2 onto ``device`` in eval mode.

        Args:
            device: Device string accepted by ``torch.device``.

        Raises:
            ModelLoadException: If loading or moving the model fails.
        """
        self.device = torch.device(device)
        # Built once per instance instead of on every extract call.
        # NOTE(review): ToTensor only rescales uint8 arrays to [0, 1];
        # this presumably expects uint8 images - confirm against callers.
        self._transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        try:
            # Load from PyTorch Hub.
            # NOTE(review): torch.hub.load fetches and runs code from the
            # named GitHub repo; consider pinning a revision - confirm policy.
            self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
            self.model.to(self.device)
            self.model.eval()
            # Freeze the parameters; this extractor is inference-only.
            for p in self.model.parameters():
                p.requires_grad_(False)
        except Exception as e:
            raise ModelLoadException(f"Không thể khởi tạo mô hình DINOv2: {e}") from e

    def _preprocess(self, img_gray: np.ndarray) -> torch.Tensor:
        """Resize one grayscale image, replicate it to 3 channels and normalise it."""
        resized = cv2.resize(img_gray, self.TARGET_SIZE, interpolation=cv2.INTER_AREA)
        rgb = np.stack([resized] * 3, axis=2)
        return self._transform(rgb)

    def extract(self, img_gray: np.ndarray) -> torch.Tensor:
        """
        Extract a flat, L2-normalised embedding from one grayscale image.

        Args:
            img_gray: Single-channel image array.

        Returns:
            1-D tensor on ``self.device`` with unit L2 norm.
        """
        tensor = self._preprocess(img_gray).unsqueeze(0).to(self.device)
        with torch.no_grad():
            feat = self.model(tensor)
        # If the model returns a dict, use its CLS-token entry.
        if isinstance(feat, dict):
            feat = feat["x_norm_clstoken"]
        return F.normalize(feat.flatten(), dim=0)

    def extract_batch(self, imgs: list[np.ndarray]) -> torch.Tensor:
        """
        Extract embeddings for several grayscale images in one forward pass.

        Args:
            imgs: Single-channel image arrays; each is resized to TARGET_SIZE.

        Returns:
            Tensor of embeddings normalised along dim 1 (one row per image).
            For an empty ``imgs`` an empty 1-D tensor (shape ``(0,)``) is
            returned.
        """
        if not imgs:
            return torch.empty(0, device=self.device)
        batch = torch.stack([self._preprocess(img) for img in imgs]).to(self.device)
        with torch.no_grad():
            feats = self.model(batch)
        # If the model returns a dict, use its CLS-token entry.
        if isinstance(feats, dict):
            feats = feats["x_norm_clstoken"]
        return F.normalize(feats, dim=1)
# Module-level cache of extractor instances, keyed by (backbone, device).
_EXTRACTOR_CACHE = {}


def get_shared_feature_extractor(backbone: str = "resnet18", device: str = "cpu") -> FeatureExtractor:
    """
    Return a cached extractor, falling back across device and model on failure.

    Fallback order: a CUDA device is replaced by CPU when CUDA is unavailable;
    a failed DINOv2 load falls back to ResNet18 on the same device; a failed
    ResNet18 load on a non-CPU device falls back to ResNet18 on CPU.

    A fallback result is cached only under the key it was actually built for,
    so the failed configuration is attempted again on later calls.

    Args:
        backbone: ``"dinov2"`` selects DINOv2; any other value selects ResNet18.
        device: Requested device string, e.g. ``"cpu"``, ``"cuda"``, ``"cuda:0"``.

    Returns:
        The shared extractor instance for the resolved (backbone, device).

    Raises:
        ModelLoadException: If ResNet18 cannot be loaded on CPU either.
    """
    actual_device = device
    # Match "cuda" and indexed forms such as "cuda:0" so both fall back alike.
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        print("[Warning] Khởi chạy trên CUDA bất khả thi. Tự động Fallback sang CPU.")
        actual_device = "cpu"

    # Every non-"dinov2" name builds a ResNet18, so they share one cache key;
    # otherwise an unknown name would cache a duplicate ResNet18.
    resolved_backbone = "dinov2" if backbone == "dinov2" else "resnet18"
    cache_key = (resolved_backbone, actual_device)
    if cache_key in _EXTRACTOR_CACHE:
        return _EXTRACTOR_CACHE[cache_key]

    try:
        if resolved_backbone == "dinov2":
            extractor = DINOv2Extractor(device=actual_device)
        else:
            extractor = DeepFeatureExtractor(device=actual_device)
    except ModelLoadException as e:
        print(f"[Warning] Khởi chạy model {backbone} thất bại: {e}.")
        if resolved_backbone == "dinov2":
            print("Tự động Fallback hạ cấp xuống ResNet18.")
            return get_shared_feature_extractor(backbone="resnet18", device=actual_device)
        if actual_device == "cpu":
            # Nothing left to fall back to; re-raise with the original traceback.
            raise
        print("Tự động Fallback hạ cấp xuống ResNet18 trên CPU.")
        return get_shared_feature_extractor(backbone="resnet18", device="cpu")
    except Exception as e:
        print(f"[Warning] Lỗi không mong đợi: {e}. Fallback sang ResNet18 CPU.")
        if resolved_backbone == "resnet18" and actual_device == "cpu":
            raise ModelLoadException(f"Không thể tải mô hình dự phòng ResNet18 trên CPU: {e}") from e
        return get_shared_feature_extractor(backbone="resnet18", device="cpu")

    _EXTRACTOR_CACHE[cache_key] = extractor
    return extractor
def choose_extractor(template: np.ndarray, resnet_ext: FeatureExtractor, dino_ext: FeatureExtractor) -> FeatureExtractor:
    """
    Pick the extractor suited to the template's size.

    Templates whose shorter side is under 56 px get the ResNet18 extractor,
    to avoid DINOv2's sparse tokens at that size; all others get DINOv2.
    """
    height, width = template.shape[:2]
    shorter_side = height if height < width else width
    return dino_ext if shorter_side >= 56 else resnet_ext
|