SPECTRE-Large / spectre /model.py
cclaess's picture
Initial commit
8b41845 verified
import torch
import torch.nn as nn
MODEL_CONFIGS = {
"spectre-small": {
"name": "spectre-small",
"backbone": "vit_small_patch16_128",
"backbone_checkpoint_path_or_url": None,
"backbone_kwargs": {
"num_classes": 0,
"global_pool": '',
"pos_embed": "rope",
"rope_kwargs": {"base": 1000.0},
"init_values": 1.0,
},
"feature_combiner": "feat_vit_small",
"feature_combiner_checkpoint_path_or_url": None,
"feature_combiner_kwargs": {
"num_classes": 0,
"global_pool": '',
"pos_embed": "rope",
"rope_kwargs": {"base": 100.0},
"init_values": 1.0,
},
"description": "SPECTRE model with ViT-Small backbone and feature combiner.",
}, # Pretrained/Distilled checkpoints will be added later
"spectre-base": {
"name": "spectre-base",
"backbone": "vit_base_patch16_128",
"backbone_checkpoint_path_or_url": None,
"backbone_kwargs": {
"num_classes": 0,
"global_pool": '',
"pos_embed": "rope",
"rope_kwargs": {"base": 1000.0},
"init_values": 1.0,
},
"feature_combiner": "feat_vit_base",
"feature_combiner_checkpoint_path_or_url": None,
"feature_combiner_kwargs": {
"num_classes": 0,
"global_pool": '',
"pos_embed": "rope",
"rope_kwargs": {"base": 100.0},
"init_values": 1.0,
},
"description": "SPECTRE model with ViT-Base backbone and feature combiner.",
}, # Pretrained/Distilled checkpoints will be added later
"spectre-large": {
"name": "spectre-large",
"backbone": "vit_large_patch16_128",
"backbone_checkpoint_path_or_url": None,
"backbone_kwargs": {
"num_classes": 0,
"global_pool": '',
"pos_embed": "rope",
"rope_kwargs": {"base": 1000.0},
"init_values": 1.0,
},
"feature_combiner": "feat_vit_large",
"feature_combiner_checkpoint_path_or_url": None,
"feature_combiner_kwargs": {
"num_classes": 0,
"global_pool": '',
"pos_embed": "rope",
"rope_kwargs": {"base": 100.0},
"init_values": 1.0,
},
"description": "SPECTRE model with ViT-Large backbone and feature combiner.",
},
"spectre-large-pretrained": {
"name": "spectre-large-pretrained",
"backbone": "vit_large_patch16_128",
"backbone_checkpoint_path_or_url": "https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_backbone_vit_large_patch16_128.pt?download=true",
"backbone_kwargs": {
"num_classes": 0,
"global_pool": '',
"pos_embed": "rope",
"rope_kwargs": {"base": 1000.0},
"init_values": 1.0,
},
"feature_combiner": "feat_vit_large",
"feature_combiner_checkpoint_path_or_url": "https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_combiner_feature_vit_large.pt?download=true",
"feature_combiner_kwargs": {
"num_classes": 0,
"global_pool": '',
"pos_embed": "rope",
"rope_kwargs": {"base": 100.0},
"init_values": 1.0,
},
"description": "Pretrained SPECTRE model with ViT-Large backbone and feature combiner.",
}
}
class SpectreImageFeatureExtractor(nn.Module):
def __init__(
self,
backbone_name: str,
backbone_kwargs: dict = {},
backbone_checkpoint_path_or_url: str | None = None,
feature_combiner_name: str | None = None,
feature_combiner_kwargs: dict = {},
feature_combiner_checkpoint_path_or_url: str | None = None,
**kwargs,
):
super().__init__()
self.backbone = None
self.feature_combiner = None
self._init_backbone(
backbone_name,
checkpoint_path_or_url=backbone_checkpoint_path_or_url,
**backbone_kwargs,
**kwargs,
)
if feature_combiner_name is not None:
self._init_feature_combiner(
feature_combiner_name,
checkpoint_path_or_url=feature_combiner_checkpoint_path_or_url,
**feature_combiner_kwargs,
**kwargs,
)
def _init_backbone(
self,
model_name: str,
checkpoint_path_or_url: str | None = None,
**kwargs
):
backbone_cls = getattr(__import__('spectre.models', fromlist=[model_name]), model_name)
self.backbone = backbone_cls(
checkpoint_path_or_url=checkpoint_path_or_url,
**kwargs,
)
def _init_feature_combiner(
self,
model_name: str,
checkpoint_path_or_url: str | None = None,
**kwargs,
):
if self.backbone.global_pool == '':
patch_dim = self.backbone.embed_dim * 2 # CLS + AVG pooled tokens
else:
patch_dim = self.backbone.embed_dim
feature_combiner_cls = getattr(__import__('spectre.models', fromlist=[model_name]), model_name)
self.feature_combiner = feature_combiner_cls(
patch_dim=patch_dim,
checkpoint_path_or_url=checkpoint_path_or_url,
**kwargs,
)
def extract_backbone_features(
self,
x: torch.Tensor,
):
"""
Extract features from the backbone for a batch of image sets. Input is expected to be of
shape (B, N, C, H, W, D), where B is the batch size, N is the number of image patches per
image, C is the number of channels, H is height, W is width, and D is depth.
The output will be a tensor of extracted features (B, N, T, F) where T is the number of
tokens and F is the feature dimension.
Args:
x (torch.Tensor): Input tensor of shape (B, N, C, H, W, D)
Returns:
torch.Tensor: Extracted features of shape (B, N, T, F)
"""
assert x.ndim == 6, "Input tensor must have 6 dimensions: (B, N, C, H, W, D)"
B, N, C, H, W, D = x.shape
x = x.view(B * N, C, H, W, D)
features = self.backbone(x)
if features.ndim == 2: # only CLS token
features = features.unsqueeze(1)
features = features.view(B, N, features.shape[1], -1)
return features
def combine_features(
self,
features: torch.Tensor,
grid_size: tuple[int, int, int],
):
"""
Combine features from multiple image patches using the feature combiner.
Args:
features (torch.Tensor): Input features of shape (B, N, T, F)
grid_size (tuple[int, int, int]): Grid size of the image patches
Returns:
torch.Tensor: Combined features of shape (B, T', F')
"""
_, N, T, _ = features.shape
assert features.ndim == 4, "Input features must have 4 dimensions: (B, N, T, F)"
assert N == grid_size[0] * grid_size[1] * grid_size[2], \
"Number of patches N must match the product of grid_size dimensions"
if T == 1: # only CLS token
features = features.squeeze(2)
else:
# We combine CLS tokens with AVG pooling of other tokens
features = torch.cat([
features[:, :, 0, :], # CLS token (B, N, F)
features[:, :, 1:, :].mean(dim=2) # AVG pooled tokens (B, N, F)
], dim=-1) # (B, N, 2F)
features = self.feature_combiner(features, grid_size) # (B, T', F')
return features
def forward(self, x, grid_size: tuple[int, int, int] | None = None):
features = self.extract_backbone_features(x)
if self.feature_combiner is not None:
assert grid_size is not None, \
"`grid_size` must be provided when using feature combiner"
features = self.combine_features(features, grid_size)
return features
@classmethod
def from_config(
cls,
config: dict,
**kwargs,
) -> 'SpectreImageFeatureExtractor':
model = cls(
backbone_name=config["backbone"],
backbone_checkpoint_path_or_url=config.get("backbone_checkpoint_path_or_url", None),
backbone_kwargs=config.get("backbone_kwargs", {}),
feature_combiner_name=config.get("feature_combiner", None),
feature_combiner_checkpoint_path_or_url=config.get("feature_combiner_checkpoint_path_or_url", None),
feature_combiner_kwargs=config.get("feature_combiner_kwargs", {}),
**kwargs,
)
return model