File size: 9,068 Bytes
8b41845
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import torch
import torch.nn as nn


MODEL_CONFIGS = {
    "spectre-small": {
        "name": "spectre-small",
        "backbone": "vit_small_patch16_128",
        "backbone_checkpoint_path_or_url": None,
        "backbone_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 1000.0},
            "init_values": 1.0,
        },
        "feature_combiner": "feat_vit_small",
        "feature_combiner_checkpoint_path_or_url": None,
        "feature_combiner_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 100.0},
            "init_values": 1.0,
        },
        "description": "SPECTRE model with ViT-Small backbone and feature combiner.",
    },  # Pretrained/Distilled checkpoints will be added later
    "spectre-base": {
        "name": "spectre-base",
        "backbone": "vit_base_patch16_128",
        "backbone_checkpoint_path_or_url": None,
        "backbone_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 1000.0},
            "init_values": 1.0,
        },
        "feature_combiner": "feat_vit_base",
        "feature_combiner_checkpoint_path_or_url": None,
        "feature_combiner_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 100.0},
            "init_values": 1.0,
        },
        "description": "SPECTRE model with ViT-Base backbone and feature combiner.",
    },  # Pretrained/Distilled checkpoints will be added later
    "spectre-large": {
        "name": "spectre-large",
        "backbone": "vit_large_patch16_128",
        "backbone_checkpoint_path_or_url": None,
        "backbone_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 1000.0},
            "init_values": 1.0,
        },
        "feature_combiner": "feat_vit_large",
        "feature_combiner_checkpoint_path_or_url": None,
        "feature_combiner_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 100.0},
            "init_values": 1.0,
        },
        "description": "SPECTRE model with ViT-Large backbone and feature combiner.",
    },
    "spectre-large-pretrained": {
        "name": "spectre-large-pretrained",
        "backbone": "vit_large_patch16_128",
        "backbone_checkpoint_path_or_url": "https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_backbone_vit_large_patch16_128.pt?download=true",
        "backbone_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 1000.0},
            "init_values": 1.0,
        },
        "feature_combiner": "feat_vit_large",
        "feature_combiner_checkpoint_path_or_url": "https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_combiner_feature_vit_large.pt?download=true",
        "feature_combiner_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 100.0},
            "init_values": 1.0,
        },
        "description": "Pretrained SPECTRE model with ViT-Large backbone and feature combiner.",
    }
}


class SpectreImageFeatureExtractor(nn.Module):
    def __init__(

        self, 

        backbone_name: str, 

        backbone_kwargs: dict = {},

        backbone_checkpoint_path_or_url: str | None = None,

        feature_combiner_name: str | None = None, 

        feature_combiner_kwargs: dict = {},

        feature_combiner_checkpoint_path_or_url: str | None = None,

        **kwargs,

    ):  
        super().__init__()
        self.backbone = None
        self.feature_combiner = None
        self._init_backbone(
            backbone_name,
            checkpoint_path_or_url=backbone_checkpoint_path_or_url,
            **backbone_kwargs,
            **kwargs,
        )
        if feature_combiner_name is not None:
            self._init_feature_combiner(
                feature_combiner_name,
                checkpoint_path_or_url=feature_combiner_checkpoint_path_or_url,
                **feature_combiner_kwargs,
                **kwargs,
            )

    def _init_backbone(

        self, 

        model_name: str,

        checkpoint_path_or_url: str | None = None,

        **kwargs

    ):
        backbone_cls = getattr(__import__('spectre.models', fromlist=[model_name]), model_name)
        self.backbone = backbone_cls(
            checkpoint_path_or_url=checkpoint_path_or_url, 
            **kwargs,
        )
        
    def _init_feature_combiner(

        self, 

        model_name: str,

        checkpoint_path_or_url: str | None = None,

        **kwargs,

    ):
        if self.backbone.global_pool == '':
            patch_dim = self.backbone.embed_dim * 2  # CLS + AVG pooled tokens
        else:
            patch_dim = self.backbone.embed_dim

        feature_combiner_cls = getattr(__import__('spectre.models', fromlist=[model_name]), model_name)
        self.feature_combiner = feature_combiner_cls(
            patch_dim=patch_dim, 
            checkpoint_path_or_url=checkpoint_path_or_url, 
            **kwargs,
        )

    def extract_backbone_features(

        self, 

        x: torch.Tensor,

    ):
        """

        Extract features from the backbone for a batch of image sets. Input is expected to be of 

        shape (B, N, C, H, W, D), where B is the batch size, N is the number of image patches per 

        image, C is the number of channels, H is height, W is width, and D is depth.

        The output will be a tensor of extracted features (B, N, T, F) where T is the number of 

        tokens and F is the feature dimension.



        Args:

            x (torch.Tensor): Input tensor of shape (B, N, C, H, W, D)

        Returns:

            torch.Tensor: Extracted features of shape (B, N, T, F)

        """
        assert x.ndim == 6, "Input tensor must have 6 dimensions: (B, N, C, H, W, D)"
        B, N, C, H, W, D = x.shape
        x = x.view(B * N, C, H, W, D)
        features = self.backbone(x)
        if features.ndim == 2:  # only CLS token
            features = features.unsqueeze(1)
        features = features.view(B, N, features.shape[1], -1)
        return features

    def combine_features(

        self, 

        features: torch.Tensor,

        grid_size: tuple[int, int, int],

    ):
        """

        Combine features from multiple image patches using the feature combiner.



        Args:

            features (torch.Tensor): Input features of shape (B, N, T, F)

            grid_size (tuple[int, int, int]): Grid size of the image patches

        Returns:

            torch.Tensor: Combined features of shape (B, T', F')

        """
        _, N, T, _ = features.shape
        assert features.ndim == 4, "Input features must have 4 dimensions: (B, N, T, F)"
        assert N == grid_size[0] * grid_size[1] * grid_size[2], \
            "Number of patches N must match the product of grid_size dimensions"

        if T == 1:  # only CLS token
            features = features.squeeze(2)
        else:
            # We combine CLS tokens with AVG pooling of other tokens
            features = torch.cat([
                features[:, :, 0, :],  # CLS token (B, N, F)
                features[:, :, 1:, :].mean(dim=2)  # AVG pooled tokens (B, N, F)
            ], dim=-1)  # (B, N, 2F)
        features = self.feature_combiner(features, grid_size)  # (B, T', F')
        return features

    def forward(self, x, grid_size: tuple[int, int, int] | None = None):
        features = self.extract_backbone_features(x)
        if self.feature_combiner is not None:
            assert grid_size is not None, \
                "`grid_size` must be provided when using feature combiner"
            features = self.combine_features(features, grid_size)
        return features

    @classmethod
    def from_config(

        cls, 

        config: dict,

        **kwargs,

    ) -> 'SpectreImageFeatureExtractor':
        
        model = cls(
            backbone_name=config["backbone"],
            backbone_checkpoint_path_or_url=config.get("backbone_checkpoint_path_or_url", None),
            backbone_kwargs=config.get("backbone_kwargs", {}),
            feature_combiner_name=config.get("feature_combiner", None),
            feature_combiner_checkpoint_path_or_url=config.get("feature_combiner_checkpoint_path_or_url", None),
            feature_combiner_kwargs=config.get("feature_combiner_kwargs", {}),
            **kwargs,
        )
        return model