from typing import *

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import AutoModel


# Per-channel RGB mean/std applied by both extractors (the commonly used
# ImageNet statistics). Inputs are expected to be scaled to [0, 1] first.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _module_device(module) -> torch.device:
    """Return the device holding ``module``'s first parameter.

    Falls back to CPU when the module exposes no parameters.
    """
    first_param = next(module.parameters(), None)
    if first_param is None:
        return torch.device("cpu")
    return first_param.device


def _prepare_image_batch(image, target_size, transform, device=None):
    """Turn ``image`` into a transformed batch tensor on ``device``.

    Args:
        image: Either a 4-D tensor, which is used as-is (it is NOT resized
            here and is presumably already scaled to [0, 1] like the PIL
            path -- this is not checked), or a list of ``PIL.Image.Image``
            objects, each of which is resized to
            ``(target_size, target_size)`` with Lanczos resampling, converted
            to RGB, scaled to [0, 1] and laid out channel-first.
        target_size: Side length used when resizing PIL images.
        transform: Callable applied to the batched tensor (e.g. a
            normalisation transform).
        device: Destination device. ``None`` keeps the historical behaviour
            of moving the batch to the current CUDA device.

    Returns:
        The transformed batch as a tensor located on ``device``.

    Raises:
        AssertionError: If a tensor input is not 4-D, or a list contains
            something other than PIL images.
        ValueError: If ``image`` is neither a tensor nor a list.
    """
    if device is None:
        # Backward-compatible default: the original code always used .cuda().
        device = torch.device("cuda")

    if isinstance(image, torch.Tensor):
        assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
    elif isinstance(image, list):
        assert all(isinstance(i, Image.Image) for i in image)
        frames = []
        for pil_image in image:
            resized = pil_image.resize((target_size, target_size), Image.LANCZOS)
            pixels = np.array(resized.convert('RGB')).astype(np.float32) / 255
            # HWC -> CHW so the frames stack into (B, C, H, W).
            frames.append(torch.from_numpy(pixels).permute(2, 0, 1).float())
        # Move before transforming so the transform runs on the target device.
        image = torch.stack(frames).to(device)
    else:
        raise ValueError(f"Unsupported type of image: {type(image)}")

    # The trailing .to() is what relocates tensor inputs; it is a no-op for
    # batches that are already on ``device``.
    return transform(image).to(device)


class DinoV2FeatureExtractor:
    """Wrap a DINOv2 model loaded through ``torch.hub`` as a feature extractor.

    Calling the instance returns the model's ``x_prenorm`` output after a
    parameter-free layer norm over the last dimension.
    """

    def __init__(self, model_name: str):
        """Load ``model_name`` from the ``facebookresearch/dinov2`` hub repo.

        The model is put in eval mode and stays on whatever device the hub
        loader placed it; use ``to`` / ``cuda`` / ``cpu`` to move it.
        """
        self.model_name = model_name
        self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True)
        self.model.eval()
        self.transform = transforms.Compose([
            transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ])

    def to(self, device):
        """Move the wrapped model to ``device``."""
        self.model.to(device)

    def cuda(self):
        """Move the wrapped model to the current CUDA device."""
        self.model.cuda()

    def cpu(self):
        """Move the wrapped model to the CPU."""
        self.model.cpu()

    @torch.no_grad()
    def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
        """Extract features from a batch of images.

        Args:
            image: A 4-D tensor, or a list of PIL images (resized to
                518x518). Tensor inputs are not resized.

        Returns:
            The ``x_prenorm`` features, layer-normalised over the last
            dimension without learned scale or bias.
        """
        # Follow the model's device instead of assuming CUDA, so that
        # ``cpu()`` / ``to(device)`` are honoured at inference time.
        image = _prepare_image_batch(
            image, 518, self.transform, device=_module_device(self.model)
        )
        features = self.model(image, is_training=True)['x_prenorm']
        return F.layer_norm(features, features.shape[-1:])


class DinoV3FeatureExtractor:
    """Wrap a model loaded with ``transformers.AutoModel`` as a feature extractor.

    Calling the instance returns the last entry of the model's
    ``hidden_states`` after a parameter-free layer norm over the last
    dimension.
    """

    def __init__(self, model_name: str, image_size=512):
        """Load ``model_name`` via ``AutoModel.from_pretrained``.

        Args:
            model_name: Identifier or path accepted by ``from_pretrained``.
            image_size: Side length PIL inputs are resized to before
                inference.
        """
        self.model_name = model_name
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ])

    def to(self, device):
        """Move the wrapped model to ``device``."""
        self.model.to(device)

    def cuda(self):
        """Move the wrapped model to the current CUDA device."""
        self.model.cuda()

    def cpu(self):
        """Move the wrapped model to the CPU."""
        self.model.cpu()

    @torch.no_grad()
    def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
        """Extract features from a batch of images.

        Args:
            image: A 4-D tensor, or a list of PIL images (resized to
                ``image_size`` x ``image_size``). Tensor inputs are not
                resized.

        Returns:
            The final ``hidden_states`` entry, layer-normalised over the
            last dimension without learned scale or bias.
        """
        # Follow the model's device instead of assuming CUDA, so that
        # ``cpu()`` / ``to(device)`` are honoured at inference time.
        image = _prepare_image_batch(
            image, self.image_size, self.transform, device=_module_device(self.model)
        )
        outputs = self.model(pixel_values=image, output_hidden_states=True)
        # NOTE(review): presumably the representation before the model's own
        # final norm, as the original variable name suggests -- confirm
        # against the loaded architecture.
        prenorm_features = outputs.hidden_states[-1]
        return F.layer_norm(prenorm_features, prenorm_features.shape[-1:])