Prior2DSM / src /dinov3 /hub /classifiers.py
osherr's picture
Upload 222 files
bc90483 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import os
from enum import Enum
from typing import Optional
import torch
import torch.nn as nn
from .backbones import (
dinov3_vit7b16,
Weights as BackboneWeights,
convert_path_or_url_to_url,
)
from .utils import DINOV3_BASE_URL
class ClassifierWeights(Enum):
    """Identifiers of the pretrained linear-classifier checkpoints known to this module."""

    # The lower-cased value is embedded in the checkpoint filename by
    # _make_dinov3_linear_classification_head.
    IMAGENET1K = "IMAGENET1K"
def _make_dinov3_linear_classification_head(
    *,
    backbone_name: str = "dinov3_vit7b16",
    embed_dim: int = 8192,
    pretrained: bool = True,
    classifier_weights: ClassifierWeights | str = ClassifierWeights.IMAGENET1K,
    check_hash: bool = False,
    **kwargs,
):
    """Create the linear classification head and optionally load its weights.

    Args:
        backbone_name: Backbone identifier; used as a URL path component and
            as the prefix of the checkpoint filename.
        embed_dim: Number of input features of the linear layer.
        pretrained: When True, fetch a state dict and load it into the head;
            when False, the freshly initialised layer is returned as is.
        classifier_weights: Either a ``ClassifierWeights`` member (the
            checkpoint is then located under ``DINOV3_BASE_URL``) or a string
            that is handed to ``convert_path_or_url_to_url``.
        check_hash: Forwarded to ``torch.hub.load_state_dict_from_url``.
        **kwargs: Only the optional ``hash`` entry is read; it overrides the
            default hash suffix of the checkpoint filename. Other entries are
            ignored.

    Returns:
        An ``nn.Linear(embed_dim, 1000)`` module.

    Raises:
        AssertionError: If ``classifier_weights`` is a ``ClassifierWeights``
            member other than ``IMAGENET1K``.
    """
    # Function-scope import: URLs always use forward slashes, whereas
    # os.path.join would insert backslashes on Windows.
    import posixpath

    linear_head = nn.Linear(embed_dim, 1_000)
    if not pretrained:
        return linear_head

    if isinstance(classifier_weights, ClassifierWeights):
        # Explicit raise (rather than an assert statement) so the check is
        # still enforced when Python runs with -O.
        if classifier_weights != ClassifierWeights.IMAGENET1K:
            raise AssertionError(f"Unsupported weights for linear classifier: {classifier_weights}")
        weights_name = classifier_weights.value.lower()
        weights_hash = kwargs.get("hash", "90d8ed92")
        model_filename = f"{backbone_name}_{weights_name}_linear_head-{weights_hash}.pth"
        url = posixpath.join(DINOV3_BASE_URL, backbone_name, model_filename)
    else:
        url = convert_path_or_url_to_url(classifier_weights)

    state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash)
    linear_head.load_state_dict(state_dict, strict=True)
    return linear_head
class _LinearClassifierWrapper(nn.Module):
    """Chain a feature backbone with a linear head into a single module.

    The head is fed the concatenation of the backbone's normalised CLS token
    and the mean of its normalised patch tokens.
    """

    def __init__(self, *, backbone: nn.Module, linear_head: nn.Module):
        """Store the two sub-modules.

        Args:
            backbone: Object exposing ``forward_features(x)`` that returns a
                mapping with ``"x_norm_clstoken"`` and ``"x_norm_patchtokens"``.
            linear_head: Module applied to the concatenated features.
        """
        super().__init__()
        self.backbone = backbone
        self.linear_head = linear_head

    def forward(self, x):
        """Run the backbone on ``x`` and return the head's output."""
        features = self.backbone.forward_features(x)
        # Presumably (batch, dim) and (batch, num_patches, dim) -- confirm
        # against the backbone; the patch tokens are averaged over dim 1.
        cls_embedding = features["x_norm_clstoken"]
        pooled_patches = features["x_norm_patchtokens"].mean(dim=1)
        head_input = torch.cat((cls_embedding, pooled_patches), dim=1)
        return self.linear_head(head_input)
def _make_dinov3_linear_classifier(
    *,
    backbone_name: str = "dinov3_vit7b16",
    pretrained: bool = True,
    classifier_weights: ClassifierWeights | str = ClassifierWeights.IMAGENET1K,
    backbone_weights: BackboneWeights | str = BackboneWeights.LVD1689M,
    check_hash: bool = False,
    **kwargs,
):
    """Assemble a backbone plus linear head into one classifier module.

    Args:
        backbone_name: Name of the backbone; only ``"dinov3_vit7b16"`` is
            accepted.
        pretrained: Passed to both the backbone factory and the head factory.
        classifier_weights: Weights selector for the linear head.
        backbone_weights: Weights selector for the backbone.
        check_hash: Passed to both the backbone factory and the head factory
            so that every downloaded checkpoint is treated the same way.
        **kwargs: Forwarded unchanged to
            ``_make_dinov3_linear_classification_head``.

    Returns:
        A ``_LinearClassifierWrapper`` holding the backbone and the head.

    Raises:
        AssertionError: If ``backbone_name`` is not ``"dinov3_vit7b16"``.
    """
    if backbone_name != "dinov3_vit7b16":
        raise AssertionError(f"Unsupported backbone: {backbone_name}, linear classifiers are provided only for ViT-7b")

    backbone = dinov3_vit7b16(pretrained=pretrained, weights=backbone_weights, check_hash=check_hash)

    # The wrapper concatenates the CLS token with the mean patch token, so the
    # head consumes twice the backbone's embedding width.
    linear_head = _make_dinov3_linear_classification_head(
        backbone_name=backbone_name,
        embed_dim=2 * backbone.embed_dim,
        pretrained=pretrained,
        classifier_weights=classifier_weights,
        # Previously dropped here, which meant check_hash=True only ever
        # applied to the backbone checkpoint.
        check_hash=check_hash,
        **kwargs,
    )
    return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head)
def dinov3_vit7b16_lc(
    *,
    pretrained: bool = True,
    weights: ClassifierWeights | str = ClassifierWeights.IMAGENET1K,
    backbone_weights: BackboneWeights | str = BackboneWeights.LVD1689M,
    check_hash: bool = False,
    **kwargs,
):
    """
    Linear classifier on top of a DINOv3 ViT-7B/16 backbone pretrained on the LVD-1689M dataset and trained on ImageNet-1k.

    Args:
        pretrained: Forwarded as ``pretrained`` to the classifier factory.
        weights: Forwarded as ``classifier_weights`` (linear-head weights).
        backbone_weights: Forwarded as ``backbone_weights``.
        check_hash: Forwarded as ``check_hash``.
        **kwargs: Forwarded unchanged to ``_make_dinov3_linear_classifier``.

    Returns:
        Whatever ``_make_dinov3_linear_classifier`` returns for the
        ``"dinov3_vit7b16"`` backbone.
    """
    # Thin entry point: pins the backbone name and renames `weights` to the
    # factory's `classifier_weights` argument.
    classifier = _make_dinov3_linear_classifier(
        backbone_name="dinov3_vit7b16",
        backbone_weights=backbone_weights,
        classifier_weights=weights,
        pretrained=pretrained,
        check_hash=check_hash,
        **kwargs,
    )
    return classifier