import timm
import torch
import torch.nn as nn
from PIL import Image
from timm.models.eva import Eva


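class PerceptionEncoder(nn.Module):
    """Perception Encoder (PE) backbone loaded through timm, with the
    classification head removed so that forward() returns features."""

    def __init__(
        self,
        model_name: str = "vit_pe_core_large_patch14_336",
        img_size: int | None = None,
    ):
        super().__init__()

        # Only enable dynamic image sizes when a custom resolution is requested;
        # this lets the model accept inputs that differ from its default size.
        dynamic_img_size = img_size is not None

        self.backbone: Eva = timm.create_model(
            model_name,
            pretrained=True,
            dynamic_img_size=dynamic_img_size,
        )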

        # Model-specific eval transforms (resize, center crop, normalization)
        data_config = timm.data.resolve_model_data_config(self.backbone)

        # Override the target resolution so preprocessing matches img_size
        if img_size is not None:
            data_config["input_size"] = (3, img_size, img_size)

        self._preprocess = timm.data.create_transform(**data_config, is_training=False)

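        # Replace the classification head so forward() returns features, not logits
        self.backbone.head = nn.Identity()

        # Dimensionality of the features returned by forward()
        self.features_dim = self.backbone.num_features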

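    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image into a normalized (3, H, W) tensor."""
        return self._preprocess(image)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Map a (B, 3, H, W) batch to (B, features_dim) features."""
        return self.backbone(inputs)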

    def get_features_dim(self) -> int:
        return self.features_dim


if __name__ == "__main__":
    import autorootcwd  # noqa: F401

    from src.config import Backbone
    from src.encoders._common import inference

    model = PerceptionEncoder(Backbone.PerceptionEncoder_B_p16_224.value, img_size=224)
    inference(model)
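Passing `img_size` turns on `dynamic_img_size`, which matters because a ViT splits its input into a grid of non-overlapping patches. The grid depends on both the input resolution and the patch size. The snippet below is a small, framework-free sketch of that arithmetic.

`patch_grid` is a hypothetical helper, not part of timm. It assumes square inputs and a strided patch embedding without padding, so leftover border pixels are dropped (floor division).

```python
def patch_grid(img_size: int, patch_size: int) -> tuple[int, int, int]:
    """Hypothetical helper: (rows, cols, num_patches) for a square input.

    Assumes non-overlapping patches and no padding, so any leftover
    border pixels that do not fill a whole patch are dropped.
    """
    side = img_size // patch_size
    return side, side, side * side


if __name__ == "__main__":
    # Default checkpoint in this file: 336px input, patch size 14
    print(patch_grid(336, 14))  # (24, 24, 576)
    # Same patch-14 model run at 224px sees a smaller grid
    print(patch_grid(224, 14))  # (16, 16, 256)
    # A patch-16 model at 224px
    print(patch_grid(224, 16))  # (14, 14, 196)
```

Running the patch-14 model at 224px therefore changes the token grid from 24x24 to 16x16. That change is why a custom `img_size` is paired with `dynamic_img_size=True` rather than used on its own.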
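Overriding `data_config["input_size"]` also changes the eval preprocessing, not only the final crop. The sketch below reflects timm's eval transform as I understand it: each side is resized to `floor(side / crop_pct)` and then center-cropped back to `input_size`.

`eval_resize_and_crop` is a hypothetical helper for illustration only, and the `crop_pct` values below are example inputs rather than this model's actual config. Check `data_config["crop_pct"]` for the real value.

```python
import math


def eval_resize_and_crop(input_size, crop_pct):
    """Hypothetical helper: (resize_hw, crop_hw) for a timm-style eval pipeline.

    Assumes each side is resized to floor(side / crop_pct) and then
    center-cropped back to input_size[-2:].
    """
    crop_h, crop_w = input_size[-2:]
    resize_hw = (math.floor(crop_h / crop_pct), math.floor(crop_w / crop_pct))
    return resize_hw, (crop_h, crop_w)


if __name__ == "__main__":
    # crop_pct < 1: resize larger, then center-crop to the target size
    print(eval_resize_and_crop((3, 224, 224), 0.875))  # ((256, 256), (224, 224))
    # crop_pct == 1: resize directly to the target size, crop is a no-op
    print(eval_resize_and_crop((3, 336, 336), 1.0))  # ((336, 336), (336, 336))
```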