| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from copy import deepcopy |
| | from math import ( |
| | ceil, |
| | sqrt |
| | ) |
| | from huggingface_hub import PyTorchModelHubMixin |
| | import torchvision.transforms.v2.functional as TF |
| | from .dinov2 import DINOViT |
| | from .vit_w_esphere import ViT_w_Esphere |
| | from .sphere import Sphere |
| |
|
| |
|
# Per-channel normalisation statistics (R, G, B order): the standard ImageNet
# mean / standard deviation listed in the torchvision documentation. They are
# applied with ``TF.normalize`` before images are fed to the DINO backbone.
IMAGENET_DATASET_MEAN = (
    0.485,
    0.456,
    0.406,
)
IMAGENET_DATASET_STD = (
    0.229,
    0.224,
    0.225,
)
| |
|
class SphereViT(nn.Module, PyTorchModelHubMixin):
    """Predict a per-pixel distance map from a batch of images.

    A ``DINOViT`` backbone extracts features, and a ``ViT_w_Esphere`` head
    turns the resized images, those features and per-pixel directions from a
    ``Sphere`` object into a distance map.

    Config keys read by this class::

        config['spherevit']['vit_w_esphere']  # passed to ViT_w_Esphere
        config['spherevit']['sphere']         # dict incl. 'width'/'height', passed to Sphere
        config['inference']['min_pixels']     # lower bound on resized pixel count
        config['inference']['max_pixels']     # upper bound on resized pixel count
    """

    # Resized image sides are rounded up to multiples of this value; 14 is the
    # patch size assumed for the DINO backbone.
    DINO_PATCH_SIZE = 14

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dino = DINOViT()
        self.vit_w_esphere = ViT_w_Esphere(config['spherevit']['vit_w_esphere'])
        # Turn the backbone's output indices into consecutive (start, end)
        # ranges: [(0, idx0), (idx0, idx1), ...]. forward() keeps the last
        # feature of each range.
        feature_slices = self.dino.output_idx
        self.feature_slices = list(
            zip([0, *feature_slices[:-1]], feature_slices)
        )
        # Kept for backward compatibility. forward() takes the device from
        # its input tensor, so it no longer depends on this attribute.
        self.device = None

    def to(self, *args, **kwargs):
        """Move/cast the module like ``nn.Module.to`` and record its device.

        Accepts every calling convention of ``nn.Module.to`` (positional or
        keyword device / dtype / tensor). ``self.device`` is read back from
        the module's parameters after the move. A dtype-only call such as
        ``model.to(torch.float16)`` therefore no longer stores a dtype as the
        device.
        """
        module = super().to(*args, **kwargs)
        first_param = next(self.parameters(), None)
        if first_param is not None:
            self.device = first_param.device
        return module

    def _resized_shape(self, H, W):
        """Return ``(factor, H_new, W_new)`` for an ``H`` x ``W`` input.

        ``factor`` rescales the image so its pixel count is clamped to
        ``[min_pixels, max_pixels]`` from ``config['inference']``. The scaled
        sides are truncated to ints and then rounded up to multiples of
        ``DINO_PATCH_SIZE``.
        """
        inference_cfg = self.config['inference']
        current_pixels = H * W
        target_pixels = min(inference_cfg['max_pixels'],
                            max(inference_cfg['min_pixels'], current_pixels))
        factor = sqrt(target_pixels / current_pixels)
        patch = self.DINO_PATCH_SIZE
        H_new = ceil(int(H * factor) / patch) * patch
        W_new = ceil(int(W * factor) / patch) * patch
        return factor, H_new, W_new

    def forward(self, images):
        """Predict a distance map for ``images``.

        Args:
            images: 4-D tensor ``(B, C, H, W)``. It is normalised here with
                the 3-element ImageNet mean/std, so presumably 3-channel RGB
                in ``[0, 1]``; confirm against the data pipeline.

        Returns:
            The head's output, bilinearly resized back to ``(H, W)`` with
            dim 1 squeezed. This is ``(B, H, W)`` when the head emits a single
            channel.
        """
        B, _, H, W = images.shape
        # Build auxiliary tensors on the same device as the input, rather than
        # trusting a cached attribute that .cuda()/.cpu() never update.
        device = images.device
        factor, H_new, W_new = self._resized_shape(H, W)

        # Scale the sphere's width/height by the same factor as the image.
        sphere_config = deepcopy(self.config['spherevit']['sphere'])
        sphere_config['width'] *= factor
        sphere_config['height'] *= factor
        sphere = Sphere(config=sphere_config, device=device)

        images = F.interpolate(images, size=(H_new, W_new), mode='bilinear', align_corners=False)
        images = TF.normalize(
            images,
            mean=IMAGENET_DATASET_MEAN,
            std=IMAGENET_DATASET_STD,
        )

        sphere_dirs = sphere.get_directions(shape=(H_new, W_new))
        sphere_dirs = sphere_dirs.to(device=device, dtype=images.dtype)
        # Assumes get_directions returns a 4-D tensor with a leading batch
        # dim of 1; TODO confirm against Sphere.
        sphere_dirs = sphere_dirs.repeat(B, 1, 1, 1)

        features = self.dino(images)
        # Keep the last backbone output inside each (start, end) range.
        features = [
            features[i:j][-1].contiguous()
            for i, j in self.feature_slices
        ]
        distance = self.vit_w_esphere(images, features, sphere_dirs)
        distance = F.interpolate(distance, size=(H, W), mode='bilinear', align_corners=False)
        return distance.squeeze(dim=1)
| |
|