File size: 2,005 Bytes
a47307e
 
 
 
 
 
 
 
 
 
 
 
abc71eb
 
a47307e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from PIL import Image

class DinoVisionTransformerClassifier(nn.Module):
    """DINOv2 model from torch.hub with its ``linear_head`` replaced by an MLP.

    The new head is Linear -> ReLU -> Linear(512 -> 256) -> ReLU ->
    Linear(256 -> num_classes). ``forward`` accepts an image file path, a PIL
    image, or an already-preprocessed tensor.
    """

    def __init__(self, num_classes, hub_model="dinov2_vitb14_lc"):
        """Load the hub model, swap in the MLP head and move it to a device.

        Args:
            num_classes: Number of outputs of the final linear layer.
            hub_model: Entry point name in the ``facebookresearch/dinov2``
                hub repo. Defaults to ``"dinov2_vitb14_lc"`` (the original,
                hard-coded choice). The loaded model must expose a
                ``linear_head`` attribute, which is replaced here.
        """
        super().__init__()

        # CUDA when available, otherwise CPU. Image inputs are moved here.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # HACK: workaround to bypass "HTTP Error 403: rate limit exceeded".
        # This overrides a private torch.hub function for the whole process,
        # not just for this load.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.model = torch.hub.load("facebookresearch/dinov2", hub_model)

        # Size the new head to the feature width the hub's own head consumed.
        # 3840 is the width the original code hard-coded for dinov2_vitb14_lc
        # and is used only if the existing head exposes no `in_features`.
        in_features = getattr(self.model.linear_head, "in_features", 3840)
        self.model.linear_head = nn.Sequential(
            nn.Linear(in_features, 512, bias=True),
            nn.ReLU(),
            nn.Linear(512, 256, bias=True),
            nn.ReLU(),
            nn.Linear(256, num_classes, bias=True),
        )

        self.model.to(self.device)

        # Fixed 224x224 resize, then tensor conversion and normalization with
        # the standard ImageNet per-channel mean/std.
        self.transform_image = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

        self.model_name = "dinov2"

    def load_image_from_filepath(self, img: str) -> torch.Tensor:
        """Load the image file at ``img`` and return a model-ready tensor.

        ``img`` may be a ``str`` or any path-like object accepted by
        ``PIL.Image.open``. The file is closed before returning. The result
        is the same as ``load_image_from_pillowimage`` on the decoded image.
        """
        with Image.open(img) as pil_img:
            # convert() reads the pixel data, so the new image stays valid
            # after the underlying file is closed by the `with` block.
            rgb_img = pil_img.convert("RGB")
        return self.load_image_from_pillowimage(rgb_img)

    def load_image_from_pillowimage(self, img: Image.Image) -> torch.Tensor:
        """Turn a PIL image into a (1, 3, 224, 224) tensor on ``self.device``.

        Non-RGB images (grayscale, palette, RGBA, CMYK, ...) are converted
        to RGB first. Every mode therefore yields exactly three channels,
        matching the 3-channel normalization.
        """
        if img.mode != "RGB":
            img = img.convert("RGB")
        return self.transform_image(img).unsqueeze(0).to(self.device)

    def forward(self, x):
        """Run the wrapped model on ``x``.

        Args:
            x: Input to classify. A ``str`` or ``os.PathLike`` is treated as
                an image file path. A ``PIL.Image.Image`` is preprocessed
                with ``self.transform_image``. Anything else (e.g. a
                preprocessed batch tensor) is passed to the model unchanged.

        Returns:
            Whatever the wrapped hub model returns. Presumably these are the
            head's logits of shape (batch, num_classes); verify against the
            hub model.
        """
        if isinstance(x, (str, os.PathLike)):
            x = self.load_image_from_filepath(x)
        elif isinstance(x, Image.Image):
            x = self.load_image_from_pillowimage(x)
        return self.model(x)