|
|
|
|
|
|
|
|
|
|
| import os
|
| from enum import Enum
|
| from typing import Optional
|
|
|
| import torch
|
| import torch.nn as nn
|
|
|
| from .backbones import (
|
| dinov3_vit7b16,
|
| Weights as BackboneWeights,
|
| convert_path_or_url_to_url,
|
| )
|
|
|
| from .utils import DINOV3_BASE_URL
|
|
|
|
|
class ClassifierWeights(Enum):
    """Named weight sets available for the linear classification head."""

    # Member name and value are deliberately identical; the lower-cased value
    # is used when composing the checkpoint file name.
    IMAGENET1K = "IMAGENET1K"
|
|
|
|
|
def _make_dinov3_linear_classification_head(
    *,
    backbone_name: str = "dinov3_vit7b16",
    embed_dim: int = 8192,
    pretrained: bool = True,
    classifier_weights: ClassifierWeights | str = ClassifierWeights.IMAGENET1K,
    check_hash: bool = False,
    **kwargs,
):
    """Build the 1000-way linear classification head, optionally loading weights.

    Args:
        backbone_name: Backbone identifier; used as the directory and file-name
            prefix of the published checkpoint.
        embed_dim: Number of input features of the linear layer.
        pretrained: When False, the freshly initialised layer is returned and
            nothing is downloaded.
        classifier_weights: Either a ``ClassifierWeights`` member (resolved to
            a checkpoint under ``DINOV3_BASE_URL``) or a string that is handed
            to ``convert_path_or_url_to_url``.
        check_hash: Forwarded to ``torch.hub.load_state_dict_from_url``.
        **kwargs: Only ``hash`` is read: it overrides the default checkpoint
            hash suffix used in the file name.

    Returns:
        An ``nn.Linear(embed_dim, 1000)`` module.

    Raises:
        AssertionError: If ``classifier_weights`` is a ``ClassifierWeights``
            member other than ``IMAGENET1K``.
    """
    # Function-scope import keeps this block self-contained.
    import posixpath

    linear_head = nn.Linear(embed_dim, 1_000)
    if not pretrained:
        return linear_head

    if isinstance(classifier_weights, ClassifierWeights):
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type is unchanged for callers.
        if classifier_weights != ClassifierWeights.IMAGENET1K:
            raise AssertionError(f"Unsupported weights for linear classifier: {classifier_weights}")
        weights_name = classifier_weights.value.lower()
        # Named `weights_hash` to avoid shadowing the builtin `hash`.
        weights_hash = kwargs.get("hash", "90d8ed92")
        model_filename = f"{backbone_name}_{weights_name}_linear_head-{weights_hash}.pth"
        # URLs always use forward slashes; `os.path.join` would insert
        # backslashes on Windows and produce an invalid URL.
        url = posixpath.join(DINOV3_BASE_URL, backbone_name, model_filename)
    else:
        url = convert_path_or_url_to_url(classifier_weights)

    state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash)
    linear_head.load_state_dict(state_dict, strict=True)
    return linear_head
|
|
|
|
|
class _LinearClassifierWrapper(nn.Module):
    """Chain a feature backbone with a linear head into a single module.

    The backbone must expose ``forward_features`` returning a mapping with the
    keys ``"x_norm_clstoken"`` and ``"x_norm_patchtokens"``.
    """

    def __init__(self, *, backbone: nn.Module, linear_head: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.linear_head = linear_head

    def forward(self, x):
        """Return the head's output for the concatenated CLS / mean-patch features."""
        features = self.backbone.forward_features(x)
        class_embedding = features["x_norm_clstoken"]
        # Patch tokens are averaged over dim 1 — presumably the token axis;
        # confirm against the backbone's output layout.
        mean_patch_embedding = features["x_norm_patchtokens"].mean(dim=1)
        head_input = torch.cat((class_embedding, mean_patch_embedding), dim=1)
        return self.linear_head(head_input)
|
|
|
|
|
def _make_dinov3_linear_classifier(
    *,
    backbone_name: str = "dinov3_vit7b16",
    pretrained: bool = True,
    classifier_weights: ClassifierWeights | str = ClassifierWeights.IMAGENET1K,
    backbone_weights: BackboneWeights | str = BackboneWeights.LVD1689M,
    check_hash: bool = False,
    **kwargs,
):
    """Assemble a backbone plus linear head into a ready-to-use classifier.

    Args:
        backbone_name: Backbone identifier; only ``"dinov3_vit7b16"`` is
            accepted.
        pretrained: Passed to both the backbone factory and the head builder.
        classifier_weights: Weights selector for the linear head.
        backbone_weights: Weights selector for the backbone.
        check_hash: Passed to both the backbone factory and the head builder.
        **kwargs: Forwarded to ``_make_dinov3_linear_classification_head``.

    Returns:
        A ``_LinearClassifierWrapper`` holding the backbone and the head.

    Raises:
        AssertionError: If ``backbone_name`` is not ``"dinov3_vit7b16"``.
    """
    if backbone_name != "dinov3_vit7b16":
        raise AssertionError(f"Unsupported backbone: {backbone_name}, linear classifiers are provided only for ViT-7b")

    backbone = dinov3_vit7b16(pretrained=pretrained, weights=backbone_weights, check_hash=check_hash)

    # The head's input is twice the backbone embedding size because the
    # wrapper concatenates the CLS token with the mean of the patch tokens.
    linear_head = _make_dinov3_linear_classification_head(
        backbone_name=backbone_name,
        embed_dim=2 * backbone.embed_dim,
        pretrained=pretrained,
        classifier_weights=classifier_weights,
        # Previously dropped here, so the head checkpoint was never
        # hash-verified even when the caller asked for it.
        check_hash=check_hash,
        **kwargs,
    )
    return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head)
|
|
|
|
|
def dinov3_vit7b16_lc(
    *,
    pretrained: bool = True,
    weights: ClassifierWeights | str = ClassifierWeights.IMAGENET1K,
    backbone_weights: BackboneWeights | str = BackboneWeights.LVD1689M,
    check_hash: bool = False,
    **kwargs,
):
    """
    Linear classifier on top of a DINOv3 ViT-7B/16 backbone.

    By default the backbone uses the LVD-1689M weights and the linear head
    uses the ImageNet-1k weights.

    Args:
        pretrained: Forwarded unchanged to ``_make_dinov3_linear_classifier``.
        weights: Selector for the linear head weights; passed on as
            ``classifier_weights``.
        backbone_weights: Selector for the backbone weights.
        check_hash: Forwarded unchanged to ``_make_dinov3_linear_classifier``.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        Whatever ``_make_dinov3_linear_classifier`` returns for the
        ``"dinov3_vit7b16"`` backbone.
    """
    # This entry point only pins the backbone name and renames `weights`.
    classifier = _make_dinov3_linear_classifier(
        backbone_name="dinov3_vit7b16",
        backbone_weights=backbone_weights,
        classifier_weights=weights,
        pretrained=pretrained,
        check_hash=check_hash,
        **kwargs,
    )
    return classifier
|
|
|