File size: 1,849 Bytes
28d6428
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""TIPSv2 image encoder for HuggingFace."""

from dataclasses import dataclass

import torch
from transformers import AutoConfig, AutoModel, PreTrainedModel

from .configuration_tips import TIPSv2ImageConfig
from .image_encoder import (
    VisionTransformer,
    vit_base,
    vit_giant2,
    vit_large,
    vit_small,
    vit_so400m,
)


# Lookup table from a ``config.model_variant`` string to the builder function
# that constructs the corresponding VisionTransformer backbone. Dict insertion
# order is kept, so the order written here is also the order in which the
# known variants are listed when an unknown variant is rejected.
MODEL_INIT_FUNCTIONS = dict(
    vit_small=vit_small,
    vit_base=vit_base,
    vit_large=vit_large,
    vit_so400m=vit_so400m,
    vit_giant2=vit_giant2,
)


@dataclass
class TIPSv2ImageOutput:
    """Result bundle returned by ``TIPSv2ImageModel.forward``.

    The three fields mirror, in order, the tuple produced by the wrapped
    ``VisionTransformer``. Tensor shapes are determined by that backbone and
    are not constrained by this container.
    """

    # Class-token output of the backbone.
    cls_token: torch.Tensor
    # Register-token outputs of the backbone.
    register_tokens: torch.Tensor
    # Per-patch token outputs of the backbone.
    patch_tokens: torch.Tensor


class TIPSv2ImageModel(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around a TIPSv2 vision transformer.

    The backbone is chosen by ``config.model_variant`` from
    ``MODEL_INIT_FUNCTIONS`` and stored as ``self.model``. That attribute name
    is what ``base_model_prefix`` refers to.
    """

    config_class = TIPSv2ImageConfig
    base_model_prefix = "model"
    # Empty at class level: no tied weights are declared for this model.
    # NOTE(review): presumably defined here so transformers' tied-weight logic
    # finds the attribute even though __init__ does not call post_init();
    # confirm against the pinned transformers version.
    all_tied_weights_keys = {}

    def __init__(self, config: TIPSv2ImageConfig):
        """Construct the vision backbone described by ``config``.

        Args:
            config: Must provide ``model_variant`` (a key of
                ``MODEL_INIT_FUNCTIONS``), plus ``image_size``, ``patch_size``,
                ``ffn_layer`` and ``init_values``, which are forwarded to the
                selected builder as keyword arguments.

        Raises:
            ValueError: If ``config.model_variant`` is not a known variant.
        """
        super().__init__(config)

        variant = config.model_variant
        builder = MODEL_INIT_FUNCTIONS.get(variant)
        if builder is None:
            raise ValueError(
                f"Unknown model_variant={variant!r}. "
                f"Expected one of {list(MODEL_INIT_FUNCTIONS)}."
            )

        builder_kwargs = dict(
            image_size=config.image_size,
            patch_size=config.patch_size,
            ffn_layer=config.ffn_layer,
            init_values=config.init_values,
        )
        self.model: VisionTransformer = builder(**builder_kwargs)

    def forward(self, pixel_values: torch.Tensor) -> TIPSv2ImageOutput:
        """Run the backbone and wrap its three outputs.

        Args:
            pixel_values: Input image batch in whatever layout the wrapped
                ``VisionTransformer`` accepts (presumably
                ``(batch, channels, height, width)`` — TODO confirm).

        Returns:
            A ``TIPSv2ImageOutput`` holding the class token, register tokens
            and patch tokens, taken in that order from ``self.model``'s
            3-tuple result.
        """
        cls_tok, reg_toks, patch_toks = self.model(pixel_values)
        return TIPSv2ImageOutput(
            patch_tokens=patch_toks,
            register_tokens=reg_toks,
            cls_token=cls_tok,
        )


# Hook these classes into the transformers Auto* factories, so that
# AutoConfig / AutoModel resolve the "tipsv2" model type to them.
# This runs as an import-time side effect of this module.
# exist_ok=True tolerates an already-existing registration for the same key
# (e.g. when this module is imported more than once).
# The config must be registered before the model that references it.
# NOTE(review): AutoConfig.register presumably requires
# TIPSv2ImageConfig.model_type == "tipsv2" — confirm in configuration_tips.
AutoConfig.register("tipsv2", TIPSv2ImageConfig, exist_ok=True)
AutoModel.register(TIPSv2ImageConfig, TIPSv2ImageModel, exist_ok=True)