Feature Extraction
Transformers
Safetensors
English
spectre
medical-imaging
ct-scan
3d
vision-transformer
self-supervised-learning
foundation-model
radiology
custom_code
Instructions to use cclaess/SPECTRE-Large with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use cclaess/SPECTRE-Large with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="cclaess/SPECTRE-Large", trust_remote_code=True)
```
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("cclaess/SPECTRE-Large", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import torch.nn as nn | |
def _vit_kwargs(rope_base):
    """Build a fresh ViT keyword-argument dict; only the RoPE base varies."""
    return {
        "num_classes": 0,
        "global_pool": '',
        "pos_embed": "rope",
        "rope_kwargs": {"base": rope_base},
        "init_values": 1.0,
    }


def _spectre_config(name, size, backbone_url=None, combiner_url=None, pretrained=False):
    """Assemble one registry entry for the ViT variant named by ``size``.

    Every call creates new nested dicts, so entries never share mutable state.
    """
    prefix = "Pretrained " if pretrained else ""
    return {
        "name": name,
        "backbone": f"vit_{size}_patch16_128",
        "backbone_checkpoint_path_or_url": backbone_url,
        "backbone_kwargs": _vit_kwargs(1000.0),
        "feature_combiner": f"feat_vit_{size}",
        "feature_combiner_checkpoint_path_or_url": combiner_url,
        "feature_combiner_kwargs": _vit_kwargs(100.0),
        "description": (
            f"{prefix}SPECTRE model with ViT-{size.capitalize()} "
            "backbone and feature combiner."
        ),
    }


# Registry of named SPECTRE configurations, keyed by configuration name.
MODEL_CONFIGS = {
    # Pretrained/Distilled checkpoints will be added later
    "spectre-small": _spectre_config("spectre-small", "small"),
    # Pretrained/Distilled checkpoints will be added later
    "spectre-base": _spectre_config("spectre-base", "base"),
    "spectre-large": _spectre_config("spectre-large", "large"),
    "spectre-large-pretrained": _spectre_config(
        "spectre-large-pretrained",
        "large",
        backbone_url="https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_backbone_vit_large_patch16_128.pt?download=true",
        combiner_url="https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_combiner_feature_vit_large.pt?download=true",
        pretrained=True,
    ),
}
| class SpectreImageFeatureExtractor(nn.Module): | |
| def __init__( | |
| self, | |
| backbone_name: str, | |
| backbone_kwargs: dict = {}, | |
| backbone_checkpoint_path_or_url: str | None = None, | |
| feature_combiner_name: str | None = None, | |
| feature_combiner_kwargs: dict = {}, | |
| feature_combiner_checkpoint_path_or_url: str | None = None, | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| self.backbone = None | |
| self.feature_combiner = None | |
| self._init_backbone( | |
| backbone_name, | |
| checkpoint_path_or_url=backbone_checkpoint_path_or_url, | |
| **backbone_kwargs, | |
| **kwargs, | |
| ) | |
| if feature_combiner_name is not None: | |
| self._init_feature_combiner( | |
| feature_combiner_name, | |
| checkpoint_path_or_url=feature_combiner_checkpoint_path_or_url, | |
| **feature_combiner_kwargs, | |
| **kwargs, | |
| ) | |
| def _init_backbone( | |
| self, | |
| model_name: str, | |
| checkpoint_path_or_url: str | None = None, | |
| **kwargs | |
| ): | |
| backbone_cls = getattr(__import__('spectre.models', fromlist=[model_name]), model_name) | |
| self.backbone = backbone_cls( | |
| checkpoint_path_or_url=checkpoint_path_or_url, | |
| **kwargs, | |
| ) | |
| def _init_feature_combiner( | |
| self, | |
| model_name: str, | |
| checkpoint_path_or_url: str | None = None, | |
| **kwargs, | |
| ): | |
| if self.backbone.global_pool == '': | |
| patch_dim = self.backbone.embed_dim * 2 # CLS + AVG pooled tokens | |
| else: | |
| patch_dim = self.backbone.embed_dim | |
| feature_combiner_cls = getattr(__import__('spectre.models', fromlist=[model_name]), model_name) | |
| self.feature_combiner = feature_combiner_cls( | |
| patch_dim=patch_dim, | |
| checkpoint_path_or_url=checkpoint_path_or_url, | |
| **kwargs, | |
| ) | |
| def extract_backbone_features( | |
| self, | |
| x: torch.Tensor, | |
| ): | |
| """ | |
| Extract features from the backbone for a batch of image sets. Input is expected to be of | |
| shape (B, N, C, H, W, D), where B is the batch size, N is the number of image patches per | |
| image, C is the number of channels, H is height, W is width, and D is depth. | |
| The output will be a tensor of extracted features (B, N, T, F) where T is the number of | |
| tokens and F is the feature dimension. | |
| Args: | |
| x (torch.Tensor): Input tensor of shape (B, N, C, H, W, D) | |
| Returns: | |
| torch.Tensor: Extracted features of shape (B, N, T, F) | |
| """ | |
| assert x.ndim == 6, "Input tensor must have 6 dimensions: (B, N, C, H, W, D)" | |
| B, N, C, H, W, D = x.shape | |
| x = x.view(B * N, C, H, W, D) | |
| features = self.backbone(x) | |
| if features.ndim == 2: # only CLS token | |
| features = features.unsqueeze(1) | |
| features = features.view(B, N, features.shape[1], -1) | |
| return features | |
| def combine_features( | |
| self, | |
| features: torch.Tensor, | |
| grid_size: tuple[int, int, int], | |
| ): | |
| """ | |
| Combine features from multiple image patches using the feature combiner. | |
| Args: | |
| features (torch.Tensor): Input features of shape (B, N, T, F) | |
| grid_size (tuple[int, int, int]): Grid size of the image patches | |
| Returns: | |
| torch.Tensor: Combined features of shape (B, T', F') | |
| """ | |
| _, N, T, _ = features.shape | |
| assert features.ndim == 4, "Input features must have 4 dimensions: (B, N, T, F)" | |
| assert N == grid_size[0] * grid_size[1] * grid_size[2], \ | |
| "Number of patches N must match the product of grid_size dimensions" | |
| if T == 1: # only CLS token | |
| features = features.squeeze(2) | |
| else: | |
| # We combine CLS tokens with AVG pooling of other tokens | |
| features = torch.cat([ | |
| features[:, :, 0, :], # CLS token (B, N, F) | |
| features[:, :, 1:, :].mean(dim=2) # AVG pooled tokens (B, N, F) | |
| ], dim=-1) # (B, N, 2F) | |
| features = self.feature_combiner(features, grid_size) # (B, T', F') | |
| return features | |
| def forward(self, x, grid_size: tuple[int, int, int] | None = None): | |
| features = self.extract_backbone_features(x) | |
| if self.feature_combiner is not None: | |
| assert grid_size is not None, \ | |
| "`grid_size` must be provided when using feature combiner" | |
| features = self.combine_features(features, grid_size) | |
| return features | |
| def from_config( | |
| cls, | |
| config: dict, | |
| **kwargs, | |
| ) -> 'SpectreImageFeatureExtractor': | |
| model = cls( | |
| backbone_name=config["backbone"], | |
| backbone_checkpoint_path_or_url=config.get("backbone_checkpoint_path_or_url", None), | |
| backbone_kwargs=config.get("backbone_kwargs", {}), | |
| feature_combiner_name=config.get("feature_combiner", None), | |
| feature_combiner_checkpoint_path_or_url=config.get("feature_combiner_checkpoint_path_or_url", None), | |
| feature_combiner_kwargs=config.get("feature_combiner_kwargs", {}), | |
| **kwargs, | |
| ) | |
| return model | |