# Captured from a Hugging Face Space (ZeroGPU / "Running on Zero" runtime).
from typing import *

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import AutoModel
| def _prepare_image_batch(image, target_size, transform): | |
| if isinstance(image, torch.Tensor): | |
| assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" | |
| elif isinstance(image, list): | |
| assert all(isinstance(i, Image.Image) for i in image) | |
| image = [i.resize((target_size, target_size), Image.LANCZOS) for i in image] | |
| image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] | |
| image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] | |
| image = torch.stack(image).cuda() | |
| else: | |
| raise ValueError(f"Unsupported type of image: {type(image)}") | |
| return transform(image).cuda() | |
class DinoV2FeatureExtractor:
    """Feature extractor backed by a DINOv2 model loaded through torch.hub.

    Calling the instance returns the layer-normalized pre-norm token
    features (``x_prenorm``) produced by the backbone.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True)
        self.model.eval()
        # ImageNet normalization statistics; inputs are expected in [0, 1].
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def to(self, device):
        """Move the underlying model to ``device``."""
        self.model.to(device)

    def cuda(self):
        """Move the underlying model to the default CUDA device."""
        self.to('cuda')

    def cpu(self):
        """Move the underlying model to the CPU."""
        self.to('cpu')

    def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
        """Extract layer-normalized pre-norm features for a batch of images."""
        batch = _prepare_image_batch(image, 518, self.transform)
        prenorm = self.model(batch, is_training=True)['x_prenorm']
        return F.layer_norm(prenorm, prenorm.shape[-1:])
class DinoV3FeatureExtractor:
    """Feature extractor backed by a Hugging Face DINOv3 checkpoint.

    Calling the instance returns the layer-normalized last hidden state of
    the model (requested via ``output_hidden_states=True``).
    """

    def __init__(self, model_name: str, image_size=512):
        self.model_name = model_name
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()
        # Side length used when resizing PIL inputs in __call__.
        self.image_size = image_size
        # ImageNet normalization statistics; inputs are expected in [0, 1].
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def to(self, device):
        """Move the underlying model to ``device``."""
        self.model.to(device)

    def cuda(self):
        """Move the underlying model to the default CUDA device."""
        self.to('cuda')

    def cpu(self):
        """Move the underlying model to the CPU."""
        self.to('cpu')

    def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
        """Extract layer-normalized last-hidden-state features for a batch."""
        batch = _prepare_image_batch(image, self.image_size, self.transform)
        outputs = self.model(pixel_values=batch, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        return F.layer_norm(last_hidden, last_hidden.shape[-1:])