BOM_Detection / src /features.py
AI Bot
deploy: zero-shot bom detection
8da7bdd
Raw
History Blame Contribute Delete
7.37 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
import cv2
import numpy as np
from typing import Protocol, List
from src.exceptions import ModelLoadException
class FeatureExtractor(Protocol):
    """Structural interface shared by the feature-extraction backends.

    Any object that provides ``extract`` and ``extract_batch`` with these
    signatures satisfies the protocol; no inheritance is required.
    """

    def extract(self, img_gray: np.ndarray) -> torch.Tensor:
        """Return the feature tensor computed for a single image array."""

    def extract_batch(self, imgs: List[np.ndarray]) -> torch.Tensor:
        """Return the feature tensor computed for a list of image arrays."""
class DeepFeatureExtractor:
    """Feature extractor built on the early layers of a pretrained ResNet18.

    Only the first six child modules of the network are kept, so the output is
    still a spatial feature map (edge/corner information is preserved). The map
    is flattened and L2-normalised before being returned.
    """

    TARGET_SIZE = (128, 128)

    def __init__(self, device: str = "cpu") -> None:
        """Load the pretrained backbone, truncate it and move it to ``device``.

        Args:
            device: Device string understood by ``torch.device``.

        Raises:
            ModelLoadException: If the model cannot be created, loaded or moved
                to the requested device.
        """
        self.device = torch.device(device)
        # Built once here instead of on every extract call.
        self._transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        try:
            # Passing ``weights`` avoids the deprecation warning of ``pretrained=``.
            model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
            self.extractor = torch.nn.Sequential(*list(model.children())[:6])
            self.extractor.to(self.device)
            self.extractor.eval()
            # Inference only: freeze parameters so no gradients are tracked.
            for p in self.extractor.parameters():
                p.requires_grad_(False)
        except Exception as e:
            # Chain the cause so the original traceback stays attached.
            raise ModelLoadException(f"Không thể khởi tạo mô hình ResNet18: {e}") from e

    def _preprocess(self, img: np.ndarray) -> torch.Tensor:
        """Resize one image and turn it into a normalised 3-channel tensor."""
        resized = cv2.resize(img, self.TARGET_SIZE, interpolation=cv2.INTER_AREA)
        # Replicate the single plane three times to get an H x W x 3 array.
        rgb = np.stack([resized] * 3, axis=2)
        return self._transform(rgb)

    def extract(self, img_gray: np.ndarray) -> torch.Tensor:
        """Extract a flat, L2-normalised feature vector from one grayscale image.

        Args:
            img_gray: Image array; assumed to be single-channel (H x W).

        Returns:
            1-D tensor holding the flattened feature map, normalised to unit
            length.
        """
        tensor = self._preprocess(img_gray).unsqueeze(0).to(self.device)
        with torch.no_grad():
            feat = self.extractor(tensor)
        # Flattening keeps every spatial position in the descriptor.
        return F.normalize(feat.flatten(), dim=0)

    def extract_batch(self, imgs: list[np.ndarray]) -> torch.Tensor:
        """Extract features for several images in a single forward pass.

        Args:
            imgs: Image arrays; each is preprocessed like in ``extract``.

        Returns:
            2-D tensor with one L2-normalised row per input image. An empty
            input list yields an empty 1-D tensor on the extractor's device.
        """
        if not imgs:
            return torch.empty(0, device=self.device)
        batch = torch.stack([self._preprocess(img) for img in imgs]).to(self.device)
        with torch.no_grad():
            feats = self.extractor(batch)
        return F.normalize(feats.flatten(start_dim=1), dim=1)
class DINOv2Extractor:
    """Feature extractor backed by the DINOv2 foundation model (``dinov2_vits14``).

    Images are resized to ``TARGET_SIZE``, expanded to three channels,
    normalised and passed through the model; the result is L2-normalised.
    """

    TARGET_SIZE = (224, 224)

    def __init__(self, device: str = "cpu") -> None:
        """Load the model from PyTorch Hub and move it to ``device``.

        Args:
            device: Device string understood by ``torch.device``.

        Raises:
            ModelLoadException: If the model cannot be fetched, created or moved
                to the requested device.
        """
        self.device = torch.device(device)
        # Built once here instead of on every extract call.
        self._transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        try:
            # NOTE(review): torch.hub.load downloads and runs code from the
            # GitHub repository; consider pinning a revision if that matters.
            self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
            self.model.to(self.device)
            self.model.eval()
            # Inference only: freeze parameters so no gradients are tracked.
            for p in self.model.parameters():
                p.requires_grad_(False)
        except Exception as e:
            # Chain the cause so the original traceback stays attached.
            raise ModelLoadException(f"Không thể khởi tạo mô hình DINOv2: {e}") from e

    def _preprocess(self, img: np.ndarray) -> torch.Tensor:
        """Resize one image and turn it into a normalised 3-channel tensor."""
        resized = cv2.resize(img, self.TARGET_SIZE, interpolation=cv2.INTER_AREA)
        # Replicate the single plane three times to get an H x W x 3 array.
        rgb = np.stack([resized] * 3, axis=2)
        return self._transform(rgb)

    @staticmethod
    def _unwrap(output):
        """Return the ``x_norm_clstoken`` entry when the model output is a dict."""
        if isinstance(output, dict):
            return output["x_norm_clstoken"]
        return output

    def extract(self, img_gray: np.ndarray) -> torch.Tensor:
        """Extract a flat, L2-normalised feature vector from one grayscale image.

        Args:
            img_gray: Image array; assumed to be single-channel (H x W).

        Returns:
            1-D tensor with the flattened model output, normalised to unit
            length.
        """
        tensor = self._preprocess(img_gray).unsqueeze(0).to(self.device)
        with torch.no_grad():
            feat = self._unwrap(self.model(tensor))
        return F.normalize(feat.flatten(), dim=0)

    def extract_batch(self, imgs: list[np.ndarray]) -> torch.Tensor:
        """Extract features for several images in a single forward pass.

        Args:
            imgs: Image arrays; each is preprocessed like in ``extract``.

        Returns:
            Tensor of model outputs normalised along dimension 1 (one row per
            input image). An empty input list yields an empty 1-D tensor on the
            extractor's device.
        """
        if not imgs:
            return torch.empty(0, device=self.device)
        batch = torch.stack([self._preprocess(img) for img in imgs]).to(self.device)
        with torch.no_grad():
            feats = self._unwrap(self.model(batch))
        return F.normalize(feats, dim=1)
# Module-wide cache of extractors, keyed by (backbone, resolved device).
_EXTRACTOR_CACHE: dict = {}


def get_shared_feature_extractor(backbone: str = "resnet18", device: str = "cpu") -> "FeatureExtractor":
    """Return a cached extractor, falling back on device and model failures.

    Resolution order:
      1. A CUDA device is replaced by ``"cpu"`` when CUDA is unavailable.
      2. A cached instance for ``(backbone, device)`` is returned if present.
      3. Otherwise the extractor is created; ``"dinov2"`` selects
         ``DINOv2Extractor`` and any other value selects
         ``DeepFeatureExtractor``.
      4. If creation fails, the call falls back to ResNet18 (same device for a
         DINOv2 failure, CPU otherwise).

    Only successfully created extractors are cached, and a fallback result is
    stored under the fallback's own key, so a failing backbone is attempted
    again on the next request.

    Args:
        backbone: ``"dinov2"`` or ``"resnet18"``.
        device: Device string such as ``"cpu"``, ``"cuda"`` or ``"cuda:0"``.

    Returns:
        The shared extractor instance.

    Raises:
        ModelLoadException: If even ResNet18 on the CPU cannot be created.
    """
    actual_device = device
    # Match indexed forms such as "cuda:0" too; otherwise they only reach the
    # CPU through the failure path below, which also downgrades the backbone.
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        print("[Warning] Khởi chạy trên CUDA bất khả thi. Tự động Fallback sang CPU.")
        actual_device = "cpu"

    cache_key = (backbone, actual_device)
    if cache_key in _EXTRACTOR_CACHE:
        return _EXTRACTOR_CACHE[cache_key]

    try:
        if backbone == "dinov2":
            extractor = DINOv2Extractor(device=actual_device)
        else:
            extractor = DeepFeatureExtractor(device=actual_device)
    except ModelLoadException as e:
        print(f"[Warning] Khởi chạy model {backbone} thất bại: {e}.")
        if backbone == "dinov2":
            print("Tự động Fallback hạ cấp xuống ResNet18.")
            return get_shared_feature_extractor(backbone="resnet18", device=actual_device)
        if actual_device == "cpu":
            # Nothing left to fall back to; re-raise with the original traceback.
            raise
        print("Tự động Fallback hạ cấp xuống ResNet18 trên CPU.")
        return get_shared_feature_extractor(backbone="resnet18", device="cpu")
    except Exception as e:
        # Last resort already failed: report that instead of announcing a
        # fallback that is not going to happen.
        if backbone == "resnet18" and actual_device == "cpu":
            raise ModelLoadException(f"Không thể tải mô hình dự phòng ResNet18 trên CPU: {e}") from e
        print(f"[Warning] Lỗi không mong đợi: {e}. Fallback sang ResNet18 CPU.")
        return get_shared_feature_extractor(backbone="resnet18", device="cpu")

    _EXTRACTOR_CACHE[cache_key] = extractor
    return extractor
def choose_extractor(template: np.ndarray, resnet_ext: FeatureExtractor, dino_ext: FeatureExtractor) -> FeatureExtractor:
    """Pick the extractor that suits the template's size.

    Templates whose shorter side is below 56 px are routed to ``resnet_ext``
    (to avoid sparse DINOv2 tokens); all other templates use ``dino_ext``.
    """
    rows, cols = template.shape[:2]
    shorter_side = min(rows, cols)
    return dino_ext if shorter_side >= 56 else resnet_ext