# Copyright (c) Meta Platforms, Inc. and affiliates. # # This software may be used and distributed in accordance with # the terms of the DINOv3 License Agreement. import os from enum import Enum from typing import Optional import torch import torch.nn as nn from .backbones import ( dinov3_vit7b16, Weights as BackboneWeights, convert_path_or_url_to_url, ) from .utils import DINOV3_BASE_URL class ClassifierWeights(Enum): IMAGENET1K = "IMAGENET1K" def _make_dinov3_linear_classification_head( *, backbone_name: str = "dinov3_vit7b16", embed_dim: int = 8192, pretrained: bool = True, classifier_weights: ClassifierWeights | str = ClassifierWeights.IMAGENET1K, check_hash: bool = False, **kwargs, ): linear_head = nn.Linear(embed_dim, 1_000) if pretrained: if type(classifier_weights) is ClassifierWeights: assert classifier_weights == ClassifierWeights.IMAGENET1K, ( f"Unsupported weights for linear classifier: {classifier_weights}" ) weights_name = classifier_weights.value.lower() hash = kwargs["hash"] if "hash" in kwargs else "90d8ed92" model_filename = f"{backbone_name}_{weights_name}_linear_head-{hash}.pth" url = os.path.join(DINOV3_BASE_URL, backbone_name, model_filename) else: url = convert_path_or_url_to_url(classifier_weights) state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash) linear_head.load_state_dict(state_dict, strict=True) return linear_head class _LinearClassifierWrapper(nn.Module): def __init__(self, *, backbone: nn.Module, linear_head: nn.Module): super().__init__() self.backbone = backbone self.linear_head = linear_head def forward(self, x): x = self.backbone.forward_features(x) cls_token = x["x_norm_clstoken"] patch_tokens = x["x_norm_patchtokens"] linear_input = torch.cat( [ cls_token, patch_tokens.mean(dim=1), ], dim=1, ) return self.linear_head(linear_input) def _make_dinov3_linear_classifier( *, backbone_name: str = "dinov3_vit7b16", pretrained: bool = True, classifier_weights: ClassifierWeights | str = ClassifierWeights.IMAGENET1K, backbone_weights: BackboneWeights | str = BackboneWeights.LVD1689M, check_hash: bool = False, **kwargs, ): if backbone_name == "dinov3_vit7b16": backbone = dinov3_vit7b16(pretrained=pretrained, weights=backbone_weights, check_hash=check_hash) else: raise AssertionError(f"Unsupported backbone: {backbone_name}, linear classifiers are provided only for ViT-7b") embed_dim = backbone.embed_dim linear_head = _make_dinov3_linear_classification_head( backbone_name=backbone_name, embed_dim=2 * embed_dim, pretrained=pretrained, classifier_weights=classifier_weights, **kwargs, ) return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head) def dinov3_vit7b16_lc( *, pretrained: bool = True, weights: ClassifierWeights | str = ClassifierWeights.IMAGENET1K, backbone_weights: BackboneWeights | str = BackboneWeights.LVD1689M, check_hash: bool = False, **kwargs, ): """ Linear classifier on top of a DINOv3 ViT-7B/16 backbone pretrained on the LVD-1689M dataset and trained on ImageNet-1k. """ return _make_dinov3_linear_classifier( backbone_name="dinov3_vit7b16", pretrained=pretrained, classifier_weights=weights, backbone_weights=backbone_weights, check_hash=check_hash, **kwargs, )