File size: 2,717 Bytes
7382c66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from math import (
    ceil, 
    sqrt
)
from huggingface_hub import PyTorchModelHubMixin
import torchvision.transforms.v2.functional as TF
from .dinov2 import DINOViT
from .vit_w_esphere import ViT_w_Esphere
from .sphere import Sphere


# Per-channel mean and standard deviation that SphereViT.forward passes to
# TF.normalize (the names indicate ImageNet dataset statistics).
IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DATASET_STD = (0.229, 0.224, 0.225)

class SphereViT(nn.Module, PyTorchModelHubMixin):
    """Predicts a per-pixel distance map from a batch of images.

    The model runs a `DINOViT` backbone, picks one feature map per backbone
    stage, and feeds the images, those features and per-pixel sphere
    directions (from `Sphere.get_directions`) into a `ViT_w_Esphere` head.

    Config keys read by this class:
        config['spherevit']['vit_w_esphere']: passed to `ViT_w_Esphere`.
        config['spherevit']['sphere']: passed (after rescaling its 'width'
            and 'height' entries) to `Sphere`.
        config['inference']['min_pixels'] / ['max_pixels']: bounds on the
            pixel count the input is resized to before inference.
    """

    def __init__(self, config):
        """Build the backbone and head and precompute the feature slices.

        Args:
            config: nested mapping with the keys listed in the class docstring.
        """
        super().__init__()
        self.config = config
        self.dino = DINOViT()
        self.vit_w_esphere = ViT_w_Esphere(config['spherevit']['vit_w_esphere'])
        feature_slices = self.dino.output_idx
        # Turn the backbone's end indices into (start, end) pairs:
        # [a, b, c] -> [(0, a), (a, b), (b, c)].
        self.feature_slices = list(
            zip([0, *feature_slices[:-1]], feature_slices)
        )
        # Device recorded by `to()`; stays None until `to()` is given a device.
        self.device = None

    def to(self, *args, **kwargs):
        """Move/cast the module like `nn.Module.to` and remember the target device.

        Accepts the same positional and keyword arguments as `nn.Module.to`.
        `self.device` is only updated when a device is actually supplied
        (positionally, via `device=`, or implicitly through a tensor), so a
        dtype-only call such as `model.to(torch.float16)` no longer
        overwrites it with a dtype.

        Returns:
            Whatever `nn.Module.to` returns (the module itself).
        """
        # Delegate first so an invalid call raises before we touch our state.
        module = super().to(*args, **kwargs)
        target = kwargs.get('device', args[0] if args else None)
        if isinstance(target, torch.Tensor):
            # `nn.Module.to(tensor)` moves the module to the tensor's device.
            target = target.device
        if isinstance(target, (str, int, torch.device)):
            self.device = target
        return module

    def forward(self, images):
        """Predict a distance map at the input resolution.

        Args:
            images: tensor unpacked as (B, C, H, W). It is normalized with
                3-channel statistics, so presumably a 3-channel image batch
                with values in the range those statistics expect.

        Returns:
            Tensor of distances resized back to (H, W) with the channel
            dimension squeezed out, i.e. (B, H, W) assuming the head
            returns (B, 1, h, w).
        """
        B, _, H, W = images.shape
        # Prefer the device recorded by `to()`; if the model was never moved
        # explicitly, fall back to the device the input already lives on.
        device = self.device if self.device is not None else images.device

        # Clamp the pixel count into [min_pixels, max_pixels] and derive the
        # uniform per-side scale factor that reaches that pixel count.
        current_pixels = H * W
        target_pixels = min(self.config['inference']['max_pixels'],
            max(self.config['inference']['min_pixels'], current_pixels))
        factor = sqrt(target_pixels / current_pixels)

        # Work on a copy so the shared config is not mutated across calls.
        sphere_config = deepcopy(self.config['spherevit']['sphere'])
        sphere_config['width'] *= factor
        sphere_config['height'] *= factor
        sphere = Sphere(config=sphere_config, device=device)

        H_new = int(H * factor)
        W_new = int(W * factor)
        # Hardcoded to match the patch size used by the DINO backbone
        # (see line 51 of `src/da2/model/dinov2/dinovit.py`).
        DINO_patch_size = 14
        # Round each side up to a whole number of patches.
        H_new = ceil(H_new / DINO_patch_size) * DINO_patch_size
        W_new = ceil(W_new / DINO_patch_size) * DINO_patch_size
        images = F.interpolate(images, size=(H_new, W_new), mode='bilinear', align_corners=False)
        images = TF.normalize(
            images,
            mean=IMAGENET_DATASET_MEAN,
            std=IMAGENET_DATASET_STD,
        )

        # One set of directions for the resized resolution, matched to the
        # input's dtype and repeated along the batch dimension.
        sphere_dirs = sphere.get_directions(shape=(H_new, W_new))
        sphere_dirs = sphere_dirs.to(device=device, dtype=images.dtype)
        sphere_dirs = sphere_dirs.repeat(B, 1, 1, 1)

        features = self.dino(images)
        # Keep only the last feature of every (start, end) slice.
        features = [
            features[i:j][-1].contiguous()
            for i, j in self.feature_slices
        ]
        distance = self.vit_w_esphere(images, features, sphere_dirs)
        # Resize the prediction back to the caller's original resolution.
        distance = F.interpolate(distance, size=(H, W), mode='bilinear', align_corners=False)
        distance = distance.squeeze(dim=1) # (b, 1, h, w) -> (b, h, w)
        return distance