# phiph's picture
# Upload folder using huggingface_hub
# 7382c66 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from math import (
ceil,
sqrt
)
from huggingface_hub import PyTorchModelHubMixin
import torchvision.transforms.v2.functional as TF
from .dinov2 import DINOViT
from .vit_w_esphere import ViT_w_Esphere
from .sphere import Sphere
# Per-channel mean / standard deviation passed to `TF.normalize` on the
# resized input images inside `SphereViT.forward`.
IMAGENET_DATASET_MEAN = (
    0.485,
    0.456,
    0.406,
)
IMAGENET_DATASET_STD = (
    0.229,
    0.224,
    0.225,
)
class SphereViT(nn.Module, PyTorchModelHubMixin):
    """Distance predictor built from a `DINOViT` backbone and a `ViT_w_Esphere` head.

    The forward pass rescales the input so its pixel count lies within the
    configured bounds, builds per-pixel directions from a `Sphere` scaled by
    the same factor, runs the backbone and the head, and resizes the
    prediction back to the input resolution.

    Config keys read by this class:
        config['spherevit']['vit_w_esphere']: passed to `ViT_w_Esphere`.
        config['spherevit']['sphere']: copied per call and passed to `Sphere`;
            its 'width' and 'height' entries are multiplied by the resize factor.
        config['inference']['min_pixels'] / ['max_pixels']: bounds on the
            pixel count of the resized input.
    """

    def __init__(self, config):
        """Build the backbone and head and precompute the feature slice bounds.

        Args:
            config: nested mapping with the keys listed in the class docstring.
        """
        super().__init__()
        self.config = config
        self.dino = DINOViT()
        self.vit_w_esphere = ViT_w_Esphere(config['spherevit']['vit_w_esphere'])
        # Turn the backbone's output indices into consecutive (start, end)
        # pairs: (0, idx[0]), (idx[0], idx[1]), ...
        feature_slices = self.dino.output_idx
        self.feature_slices = list(
            zip([0, *feature_slices[:-1]], feature_slices)
        )
        # Last device requested through `to()`; None until `to()` is given one.
        self.device = None

    def to(self, *args, **kwargs):
        """Move/cast the module like `nn.Module.to`, remembering the target device.

        Accepts the same positional and keyword arguments as `nn.Module.to`.
        `self.device` is updated only when the call actually names a device
        (a `device=` keyword, a leading str / int / `torch.device`, or a
        leading tensor whose device is used); dtype-only calls leave it as is.

        Returns:
            Whatever `nn.Module.to` returns for the same arguments.
        """
        device = kwargs.get('device')
        if device is None and args:
            first = args[0]
            if isinstance(first, torch.Tensor):
                device = first.device
            elif isinstance(first, (str, int, torch.device)):
                device = first
            # Anything else (e.g. a torch.dtype) is not a device: keep the
            # previously recorded one.
        if device is not None:
            self.device = device
        return super().to(*args, **kwargs)

    def forward(self, images):
        """Predict a distance map for a batch of images.

        Args:
            images: 4-D tensor unpacked as (B, C, H, W). It is normalized with
                the 3-entry IMAGENET_DATASET_MEAN / IMAGENET_DATASET_STD, so
                C is presumably 3 (RGB) -- confirm against callers.

        Returns:
            Tensor of shape (B, H, W): the head's output resized to the input
            resolution with its channel dimension squeezed.
        """
        B, _, H, W = images.shape

        # Use the device recorded by `to()`; if none was recorded (e.g. the
        # module was never moved through `to()`), follow the input tensor.
        device = self.device if self.device is not None else images.device

        # Clamp the pixel count into [min_pixels, max_pixels] and derive the
        # per-side scale factor that reaches the clamped count.
        current_pixels = H * W
        inference_cfg = self.config['inference']
        target_pixels = min(
            inference_cfg['max_pixels'],
            max(inference_cfg['min_pixels'], current_pixels),
        )
        factor = sqrt(target_pixels / current_pixels)

        # Work on a copy so the stored config is not rescaled on every call.
        sphere_config = deepcopy(self.config['spherevit']['sphere'])
        sphere_config['width'] *= factor
        sphere_config['height'] *= factor
        sphere = Sphere(config=sphere_config, device=device)

        # Hardcoded DINO patch size; see line 51 of
        # `src/da2/model/dinov2/dinovit.py`.
        patch_size = 14
        # Round each resized side up to a whole number of patches.
        H_new = ceil(int(H * factor) / patch_size) * patch_size
        W_new = ceil(int(W * factor) / patch_size) * patch_size

        images = F.interpolate(
            images, size=(H_new, W_new), mode='bilinear', align_corners=False
        )
        images = TF.normalize(
            images,
            mean=IMAGENET_DATASET_MEAN,
            std=IMAGENET_DATASET_STD,
        )

        # Directions for the resized grid, matched to the images' dtype and
        # repeated once per batch element.
        sphere_dirs = sphere.get_directions(shape=(H_new, W_new))
        sphere_dirs = sphere_dirs.to(device=device, dtype=images.dtype)
        sphere_dirs = sphere_dirs.repeat(B, 1, 1, 1)

        # Keep the last backbone output of each (start, end) slice.
        features = self.dino(images)
        features = [
            features[i:j][-1].contiguous()
            for i, j in self.feature_slices
        ]

        distance = self.vit_w_esphere(images, features, sphere_dirs)
        # Back to the caller's resolution.
        distance = F.interpolate(
            distance, size=(H, W), mode='bilinear', align_corners=False
        )
        return distance.squeeze(dim=1)  # (b, 1, h, w) -> (b, h, w)