| |
| |
| |
| |
|
|
| from enum import Enum |
| from functools import partial |
| from typing import Optional, Tuple, Union |
|
|
| import torch |
|
|
| from .backbones import _make_dinov2_model |
| from .depth import BNHead, DepthEncoderDecoder, DPTHead |
| from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding |
|
|
|
|
class Weights(Enum):
    # Pretrained depth-head weight sets, keyed by training dataset.
    NYU = "NYU"  # indoor scenes; paired with a (0.001, 10.0) m depth range below
    KITTI = "KITTI"  # outdoor driving; paired with a (0.001, 80.0) m depth range below
|
|
|
|
def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]:
    """Return the (min_depth, max_depth) range, in meters, for a weight set.

    Untrained models and unknown weight sets fall back to the NYU-style
    indoor range; KITTI (outdoor) uses a much larger maximum depth.
    """
    default_range = (0.001, 10.0)
    if not pretrained:
        return default_range
    per_weights = {
        Weights.KITTI: (0.001, 80.0),
        Weights.NYU: default_range,
    }
    return per_weights.get(weights, default_range)
|
|
|
|
def _make_dinov2_linear_depth_head(
    *,
    embed_dim: int,
    layers: int,
    min_depth: float,
    max_depth: float,
    **kwargs,
):
    """Build a linear (BNHead) depth decoder over DINOv2 features.

    Args:
        embed_dim: backbone embedding dimension (per feature map).
        layers: number of intermediate feature maps fed to the head (1 or 4).
        min_depth: minimum predicted depth, in meters.
        max_depth: maximum predicted depth, in meters.
        **kwargs: accepted for call-site compatibility; unused.

    Returns:
        A configured BNHead instance.

    Raises:
        AssertionError: if `layers` is not 1 or 4.
    """
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")

    # A 1-layer head consumes a single feature map; a 4-layer head
    # concatenates four intermediate maps (input_transform="resize_concat").
    if layers == 1:
        in_index = [0]
    else:
        assert layers == 4
        in_index = [0, 1, 2, 3]

    return BNHead(
        classify=True,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        upsample=4,
        in_channels=[embed_dim] * len(in_index),
        in_index=in_index,
        input_transform="resize_concat",
        channels=embed_dim * len(in_index) * 2,
        align_corners=False,
        # Fix: honor the caller-supplied depth range. The previous code
        # hardcoded min_depth=0.001, max_depth=80 here, silently ignoring
        # the min_depth/max_depth parameters (wrong for NYU's 10 m range).
        min_depth=min_depth,
        max_depth=max_depth,
        loss_decode=(),
    )
|
|
|
|
def _make_dinov2_linear_depther(
    *,
    arch_name: str = "vit_large",
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    """Assemble a DINOv2 backbone with a linear (BNHead) depth decoder.

    Args:
        arch_name: backbone variant ("vit_small", "vit_base", "vit_large",
            or "vit_giant2").
        layers: number of intermediate backbone layers fed to the head (1 or 4).
        pretrained: if True, load pretrained backbone and head checkpoints.
        weights: pretrained head weight set (Weights member or its name).
        depth_range: optional (min_depth, max_depth) override in meters;
            defaults to the range matching `pretrained`/`weights`.
        **kwargs: forwarded to the backbone constructor.

    Returns:
        A DepthEncoderDecoder model.

    Raises:
        AssertionError: on an unsupported `layers` count or weights name.
    """
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    embed_dim = backbone.embed_dim
    patch_size = backbone.patch_size
    model_name = _make_dinov2_model_name(arch_name, patch_size)
    linear_depth_head = _make_dinov2_linear_depth_head(
        embed_dim=embed_dim,
        layers=layers,
        min_depth=min_depth,
        max_depth=max_depth,
    )

    layer_count = {
        "vit_small": 12,
        "vit_base": 12,
        "vit_large": 24,
        "vit_giant2": 40,
    }[arch_name]

    # 4-layer heads tap evenly spaced intermediate blocks; a 1-layer head
    # uses only the final block.
    if layers == 4:
        out_index = {
            "vit_small": [2, 5, 8, 11],
            "vit_base": [2, 5, 8, 11],
            "vit_large": [4, 11, 17, 23],
            "vit_giant2": [9, 19, 29, 39],
        }[arch_name]
    else:
        assert layers == 1
        out_index = [layer_count - 1]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head)
    # Redirect the backbone's forward to emit the selected intermediate
    # feature maps (with class tokens) rather than the final embedding.
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    # Pad inputs so their spatial dims are multiples of the patch size.
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))

    if pretrained:
        # Checkpoint naming: 4-layer heads carry the count ("linear4");
        # single-layer heads are plain "linear".
        layers_str = str(layers) if layers == 4 else ""
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        # Fix: fall back to the raw checkpoint when there is no "state_dict"
        # wrapper; previously `state_dict` was left unbound in that case and
        # the load_state_dict call below raised a NameError.
        state_dict = checkpoint.get("state_dict", checkpoint)
        model.load_state_dict(state_dict, strict=False)

    return model
|
|
|
|
def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """DINOv2 ViT-S/14 backbone paired with a linear depth head."""
    return _make_dinov2_linear_depther(
        arch_name="vit_small",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """DINOv2 ViT-B/14 backbone paired with a linear depth head."""
    return _make_dinov2_linear_depther(
        arch_name="vit_base",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """DINOv2 ViT-L/14 backbone paired with a linear depth head."""
    return _make_dinov2_linear_depther(
        arch_name="vit_large",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """DINOv2 ViT-g/14 backbone (SwiGLU FFN) paired with a linear depth head."""
    return _make_dinov2_linear_depther(
        arch_name="vit_giant2",
        layers=layers,
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
    """Build a DPT depth decoder head over DINOv2 features.

    `min_depth`/`max_depth` bound the predicted depth in meters.
    """
    # Per-stage channel widths: embed_dim // 8, // 4, // 2, // 1.
    post_channels = [embed_dim // divisor for divisor in (8, 4, 2, 1)]
    return DPTHead(
        in_channels=[embed_dim] * 4,
        channels=256,
        embed_dims=embed_dim,
        post_process_channels=post_channels,
        readout_type="project",
        min_depth=min_depth,
        max_depth=max_depth,
        loss_decode=(),
    )
|
|
|
|
def _make_dinov2_dpt_depther(
    *,
    arch_name: str = "vit_large",
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    """Assemble a DINOv2 backbone with a DPT depth decoder.

    Args:
        arch_name: backbone variant ("vit_small", "vit_base", "vit_large",
            or "vit_giant2").
        pretrained: if True, load pretrained backbone and head checkpoints.
        weights: pretrained head weight set (Weights member or its name).
        depth_range: optional (min_depth, max_depth) override in meters;
            defaults to the range matching `pretrained`/`weights`.
        **kwargs: forwarded to the backbone constructor.

    Returns:
        A DepthEncoderDecoder model.

    Raises:
        AssertionError: on an unsupported weights name.
    """
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    model_name = _make_dinov2_model_name(arch_name, backbone.patch_size)
    dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)

    # The DPT head always consumes four evenly spaced intermediate blocks.
    out_index = {
        "vit_small": [2, 5, 8, 11],
        "vit_base": [2, 5, 8, 11],
        "vit_large": [4, 11, 17, 23],
        "vit_giant2": [9, 19, 29, 39],
    }[arch_name]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head)
    # Redirect the backbone's forward to emit the selected intermediate
    # feature maps (with class tokens) rather than the final embedding.
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    # Pad inputs so their spatial dims are multiples of the patch size.
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0]))

    if pretrained:
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        # Fix: fall back to the raw checkpoint when there is no "state_dict"
        # wrapper; previously `state_dict` was left unbound in that case and
        # the load_state_dict call below raised a NameError.
        state_dict = checkpoint.get("state_dict", checkpoint)
        model.load_state_dict(state_dict, strict=False)

    return model
|
|
|
|
def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """DINOv2 ViT-S/14 backbone paired with a DPT depth head."""
    return _make_dinov2_dpt_depther(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """DINOv2 ViT-B/14 backbone paired with a DPT depth head."""
    return _make_dinov2_dpt_depther(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """DINOv2 ViT-L/14 backbone paired with a DPT depth head."""
    return _make_dinov2_dpt_depther(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """DINOv2 ViT-g/14 backbone (SwiGLU FFN) paired with a DPT depth head."""
    return _make_dinov2_dpt_depther(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|