|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import logging |
|
|
import types |
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, List, Literal, Optional |
|
|
|
|
|
import timm |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from .vit import ( |
|
|
forward_features_eva_fixed, |
|
|
make_vit_b16_backbone, |
|
|
resize_patch_embed, |
|
|
resize_vit, |
|
|
) |
|
|
|
|
|
# Module-level logger, named after this module so log output can be filtered per module.
LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Closed set of preset names accepted by `create_vit`; each name is a key of
# `VIT_CONFIG_DICT`.
ViTPreset = Literal[
    "dinov2l16_384",
]
|
|
|
|
|
|
|
|
@dataclass
class ViTConfig:
    """Configuration for ViT.

    Groups the settings `create_vit` reads when building a backbone: the
    target geometry (`img_size`, `patch_size`), the timm model to instantiate
    (`timm_*`), and which encoder layers are exposed as features.
    """

    # NOTE(review): presumably the number of input image channels; it is not
    # read by `create_vit` in this module -- confirm against other callers.
    in_chans: int
    # Forwarded to `make_vit_b16_backbone` as `vit_features`.
    embed_dim: int

    # Target side lengths; `create_vit` expands each to a square
    # (size, size) tuple before resizing the model.
    img_size: int = 384
    patch_size: int = 16

    # In case we have to resize the model from the timm defaults:
    # `timm_preset` is the model name handed to `timm.create_model`, and the
    # two sizes below are compared against `img_size` / `patch_size` to decide
    # whether the patch embedding and/or the ViT must be resized.
    timm_preset: Optional[str] = None
    timm_img_size: int = 384
    timm_patch_size: int = 14 if False else 16

    # The layers in the Beit/ViT used to construct encoder features for DPT.
    # Optional because the default is None (no layers configured).
    encoder_feature_layer_ids: Optional[List[int]] = None
    # The dimension of features of encoder layers from Beit/ViT features for DPT.
    # Optional because the default is None (no dimensions configured).
    encoder_feature_dims: Optional[List[int]] = None
|
|
|
|
|
|
|
|
# Registry of pre-defined backbone configurations, keyed by preset name.
# `create_vit` looks its `preset` argument up here.
VIT_CONFIG_DICT: Dict[ViTPreset, ViTConfig] = {
    "dinov2l16_384": ViTConfig(
        in_chans=3,
        embed_dim=1024,
        encoder_feature_layer_ids=[5, 11, 17, 23],
        encoder_feature_dims=[256, 512, 1024, 1024],
        # Target geometry differs from the timm model's (518 / 14 below), so
        # `create_vit` resizes both the patch embedding and the ViT.
        img_size=384,
        patch_size=16,
        timm_preset="vit_large_patch14_dinov2",
        timm_img_size=518,
        timm_patch_size=14,
    ),
}
|
|
|
|
|
|
|
|
def create_vit(
    preset: ViTPreset,
    use_pretrained: bool = False,
    checkpoint_uri: str | None = None,
    use_grad_checkpointing: bool = False,
) -> nn.Module:
    """Create and load a ViT backbone module.

    Args:
    ----
        preset: The ViT preset to load the pre-defined config.
        use_pretrained: Load pretrained weights if True, default is False.
        checkpoint_uri: Checkpoint to load the weights from.
        use_grad_checkpointing: Use gradient checkpointing.

    Returns:
    -------
        A Torch ViT backbone module.

    Raises:
    ------
        KeyError: If `preset` is not a key of `VIT_CONFIG_DICT`, or if the
            checkpoint's keys do not exactly match the model's state dict.

    """
    try:
        config = VIT_CONFIG_DICT[preset]
    except KeyError:
        # Same exception type as the plain lookup, but with an actionable message.
        raise KeyError(
            f"Unknown ViT preset {preset!r}; "
            f"available presets: {sorted(VIT_CONFIG_DICT)}"
        ) from None

    img_size = (config.img_size, config.img_size)
    patch_size = (config.patch_size, config.patch_size)

    if "eva02" in preset:
        model = timm.create_model(config.timm_preset, pretrained=use_pretrained)
        # Rebind forward_features on this instance to the fixed EVA implementation.
        model.forward_features = types.MethodType(forward_features_eva_fixed, model)
    else:
        model = timm.create_model(
            config.timm_preset, pretrained=use_pretrained, dynamic_img_size=True
        )

    model = make_vit_b16_backbone(
        model,
        encoder_feature_dims=config.encoder_feature_dims,
        encoder_feature_layer_ids=config.encoder_feature_layer_ids,
        vit_features=config.embed_dim,
        use_grad_checkpointing=use_grad_checkpointing,
    )

    # Resize only when the requested geometry differs from the timm model's.
    # This happens before checkpoint loading, so the checkpoint is loaded into
    # the already-resized model.
    if config.patch_size != config.timm_patch_size:
        model.model = resize_patch_embed(model.model, new_patch_size=patch_size)
    if config.img_size != config.timm_img_size:
        model.model = resize_vit(model.model, img_size=img_size)

    if checkpoint_uri is not None:
        # NOTE(security): torch.load unpickles the file; only pass checkpoints
        # from trusted sources (or consider `weights_only=True` where the
        # installed torch version supports it).
        state_dict = torch.load(checkpoint_uri, map_location="cpu")

        # strict=False so that both key lists can be inspected and reported
        # with a dedicated message below; any mismatch is still an error.
        missing_keys, unexpected_keys = model.load_state_dict(
            state_dict=state_dict, strict=False
        )

        if unexpected_keys:
            raise KeyError(f"Found unexpected keys when loading vit: {unexpected_keys}")
        if missing_keys:
            raise KeyError(f"Keys are missing when loading vit: {missing_keys}")

    LOGGER.info(model)
    # The wrapper returned by make_vit_b16_backbone holds the ViT in `.model`;
    # callers receive that inner module.
    return model.model
|
|
|