|
|
from typing import * |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torchvision import transforms |
|
|
from transformers import DINOv3ViTModel |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
class DinoV2FeatureExtractor:
    """
    Feature extractor for DINOv2 models.

    Loads a pretrained DINOv2 backbone from torch.hub, keeps it in eval
    mode, and returns layer-normalized token features for a batch of images.
    """

    def __init__(self, model_name: str):
        """
        Args:
            model_name: Name of a DINOv2 model published on the
                'facebookresearch/dinov2' torch.hub repo (e.g. 'dinov2_vitl14').
        """
        self.model_name = model_name
        # Downloads the checkpoint on first use (network I/O).
        self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True)
        self.model.eval()
        # Only ImageNet normalization; tensor inputs are assumed to already
        # be scaled to [0, 1] (PIL inputs are scaled below).
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def to(self, device):
        """Move the underlying model to ``device``. Returns self for chaining."""
        self.model.to(device)
        return self

    def cuda(self):
        """Move the underlying model to the default CUDA device. Returns self."""
        self.model.cuda()
        return self

    def cpu(self):
        """Move the underlying model to the CPU. Returns self."""
        self.model.cpu()
        return self

    @property
    def device(self) -> torch.device:
        """Device the model's parameters currently live on."""
        return next(self.model.parameters()).device

    @torch.no_grad()
    def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
        """
        Extract features from the image.

        Args:
            image: A batch of images as a tensor of shape (B, C, H, W) with
                values in [0, 1], or a list of PIL images (resized to 518x518
                internally).

        Returns:
            A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
        """
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            # 518 = 37 * 14 — presumably matches the DINOv2 ViT patch size of
            # 14 at the default pretraining resolution; TODO confirm.
            image = [i.resize((518, 518), Image.LANCZOS) for i in image]
            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
            image = torch.stack(image)
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")

        # Fix: follow the model's actual device instead of hard-coding
        # .cuda(), so the extractor also works on CPU-only machines and
        # after an explicit .cpu() / .to('cpu').
        image = self.transform(image).to(self.device)
        features = self.model(image, is_training=True)['x_prenorm']
        patchtokens = F.layer_norm(features, features.shape[-1:])
        return patchtokens
|
|
|
|
|
|
|
|
class DinoV3FeatureExtractor:
    """
    Feature extractor for DINOv3 models.

    Loads a pretrained DINOv3 ViT backbone via HuggingFace transformers,
    keeps it in eval mode, and returns layer-normalized token features
    for a batch of images.
    """

    def __init__(self, model_name: str, image_size=512):
        """
        Args:
            model_name: HuggingFace model id or local path accepted by
                ``DINOv3ViTModel.from_pretrained``.
            image_size: Side length PIL inputs are resized to before
                feature extraction. Defaults to 512.
        """
        self.model_name = model_name
        # May download weights on first use (network I/O).
        self.model = DINOv3ViTModel.from_pretrained(model_name)
        self.model.eval()
        self.image_size = image_size
        # Only ImageNet normalization; tensor inputs are assumed to already
        # be scaled to [0, 1] (PIL inputs are scaled below).
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def to(self, device):
        """Move the underlying model to ``device``. Returns self for chaining."""
        self.model.to(device)
        return self

    def cuda(self):
        """Move the underlying model to the default CUDA device. Returns self."""
        self.model.cuda()
        return self

    def cpu(self):
        """Move the underlying model to the CPU. Returns self."""
        self.model.cpu()
        return self

    @property
    def device(self) -> torch.device:
        """Device the model's parameters currently live on."""
        return next(self.model.parameters()).device

    def extract_features(self, image: torch.Tensor) -> torch.Tensor:
        """
        Run the DINOv3 backbone manually (embeddings -> rope -> layers) and
        return layer-normalized hidden states of shape (B, N, D).

        NOTE(review): this reaches into DINOv3ViTModel internals
        (``embeddings.patch_embeddings.weight``, ``rope_embeddings``,
        ``layer``) — verify these attribute names against the installed
        transformers version.
        """
        # Match the patch-embedding weight dtype (e.g. fp16/bf16 checkpoints).
        image = image.to(self.model.embeddings.patch_embeddings.weight.dtype)
        hidden_states = self.model.embeddings(image, bool_masked_pos=None)
        position_embeddings = self.model.rope_embeddings(image)

        for layer_module in self.model.layer:
            hidden_states = layer_module(
                hidden_states,
                position_embeddings=position_embeddings,
            )

        return F.layer_norm(hidden_states, hidden_states.shape[-1:])

    @torch.no_grad()
    def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
        """
        Extract features from the image.

        Args:
            image: A batch of images as a tensor of shape (B, C, H, W) with
                values in [0, 1], or a list of PIL images (resized to
                ``image_size`` internally).

        Returns:
            A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
        """
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image]
            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
            image = torch.stack(image)
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")

        # Fix: follow the model's actual device instead of hard-coding
        # .cuda(), so the extractor also works on CPU-only machines and
        # after an explicit .cpu() / .to('cpu').
        image = self.transform(image).to(self.device)
        features = self.extract_features(image)
        return features
|
|
|