# File: 3dai/trellis2/modules/image_feature_extractor.py
# Last commit 0775b31 (developerskyebrowse): switch DINOv3 to HF Hub via AutoModel.from_pretrained
from typing import *
import torch
import torch.nn.functional as F
from torchvision import transforms
from transformers import AutoModel
import numpy as np
from PIL import Image
def _prepare_image_batch(image, target_size, transform):
if isinstance(image, torch.Tensor):
assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
elif isinstance(image, list):
assert all(isinstance(i, Image.Image) for i in image)
image = [i.resize((target_size, target_size), Image.LANCZOS) for i in image]
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
image = torch.stack(image).cuda()
else:
raise ValueError(f"Unsupported type of image: {type(image)}")
return transform(image).cuda()
class DinoV2FeatureExtractor:
    """Feature extractor backed by a DINOv2 model loaded through ``torch.hub``.

    The wrapped model is put in eval mode. Calling the extractor returns the
    model's ``'x_prenorm'`` output after a parameter-free layer norm over the
    last dimension.
    """

    def __init__(self, model_name: str):
        """Load the pretrained ``model_name`` entry point from
        ``facebookresearch/dinov2`` (may download weights on first use)."""
        self.model_name = model_name
        self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True)
        self.model.eval()
        # Per-channel mean/std normalisation (the usual ImageNet statistics).
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def to(self, device):
        """Move the model to ``device``; returns ``self`` so calls can be chained."""
        self.model.to(device)
        return self

    def cuda(self):
        """Move the model to the current CUDA device; returns ``self``."""
        self.model.cuda()
        return self

    def cpu(self):
        """Move the model to the CPU; returns ``self``."""
        self.model.cpu()
        return self

    def _model_device(self):
        """Return the device of the model's first parameter, or ``None`` if it has none."""
        first_param = next(self.model.parameters(), None)
        return None if first_param is None else first_param.device

    @torch.no_grad()
    def __call__(self, image: "Union[torch.Tensor, List[Image.Image]]") -> torch.Tensor:
        """Extract features for a batch of images.

        Args:
            image: A batched image tensor or a list of PIL images (PIL inputs
                are resized to 518 x 518 by ``_prepare_image_batch``).

        Returns:
            The model's ``'x_prenorm'`` tensor, layer-normalised over its last
            dimension without learned scale or bias.
        """
        image = _prepare_image_batch(image, 518, self.transform)
        # The batch is prepared on the default CUDA device; follow the model
        # so a model moved via to()/cpu() does not hit a device mismatch.
        model_device = self._model_device()
        if model_device is not None:
            image = image.to(model_device)
        features = self.model(image, is_training=True)['x_prenorm']
        return F.layer_norm(features, features.shape[-1:])
class DinoV3FeatureExtractor:
    """Feature extractor backed by a DINOv3 model loaded from the Hugging Face Hub.

    The wrapped model is put in eval mode. Calling the extractor returns the
    last entry of the model's ``hidden_states`` after a parameter-free layer
    norm over the last dimension.
    """

    def __init__(self, model_name: str, image_size=512):
        """Load ``model_name`` with ``AutoModel.from_pretrained``.

        Args:
            model_name: Model identifier or path accepted by ``from_pretrained``.
            image_size: Side length PIL inputs are resized to before inference.
        """
        self.model_name = model_name
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()
        self.image_size = image_size
        # Per-channel mean/std normalisation (the usual ImageNet statistics).
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def to(self, device):
        """Move the model to ``device``; returns ``self`` so calls can be chained."""
        self.model.to(device)
        return self

    def cuda(self):
        """Move the model to the current CUDA device; returns ``self``."""
        self.model.cuda()
        return self

    def cpu(self):
        """Move the model to the CPU; returns ``self``."""
        self.model.cpu()
        return self

    def _model_device(self):
        """Return the device of the model's first parameter, or ``None`` if it has none."""
        first_param = next(self.model.parameters(), None)
        return None if first_param is None else first_param.device

    @torch.no_grad()
    def __call__(self, image: "Union[torch.Tensor, List[Image.Image]]") -> torch.Tensor:
        """Extract features for a batch of images.

        Args:
            image: A batched image tensor or a list of PIL images (PIL inputs
                are resized to ``self.image_size`` by ``_prepare_image_batch``).

        Returns:
            The last hidden state reported by the model, layer-normalised over
            its last dimension without learned scale or bias.
        """
        image = _prepare_image_batch(image, self.image_size, self.transform)
        # The batch is prepared on the default CUDA device; follow the model
        # so a model moved via to()/cpu() does not hit a device mismatch.
        model_device = self._model_device()
        if model_device is not None:
            image = image.to(model_device)
        outputs = self.model(pixel_values=image, output_hidden_states=True)
        # NOTE(review): presumably the final block's output before the model's
        # own output norm -- confirm against the transformers implementation.
        prenorm_features = outputs.hidden_states[-1]
        return F.layer_norm(prenorm_features, prenorm_features.shape[-1:])