"""Contains presets for ViT modules.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

import dataclasses
from typing import Any, Literal

# Names of the available ViT presets; used as the key type of the preset table.
ViTPreset = Literal["dinov2l16_384"]

# Accepted values for the ``mlp_mode`` field of the ViT configuration.
MLPMode = Literal["vanilla", "glu"]


@dataclasses.dataclass
class ViTConfig:
    """Configuration for ViT.

    The first five fields have no default and must be supplied by the caller;
    every other field falls back to the default declared here.
    """

    # Required fields (no defaults).
    in_chans: int
    embed_dim: int
    depth: int
    num_heads: int
    init_values: float

    # Image and patch size defaults.
    img_size: int = 384
    patch_size: int = 16

    # Remaining hyper-parameters and their defaults.
    num_classes: int = 21841
    mlp_ratio: float = 4.0
    drop_rate: float = 0.0
    attn_drop_rate: float = 0.0
    drop_path_rate: float = 0.0
    qkv_bias: bool = True
    global_pool: str = "avg"

    # One of the ``MLPMode`` literals ("vanilla" or "glu").
    mlp_mode: MLPMode = "vanilla"

    # Optional list of integer ids; ``None`` by default.
    # NOTE(review): presumably the indices of the blocks whose outputs are
    # exposed as intermediate features -- confirm against the consumer.
    intermediate_features_ids: list[int] | None = None

    def asdict(self) -> dict[str, Any]:
        """Return the configuration as a plain ``dict`` keyed by field name.

        Delegates to ``dataclasses.asdict``, which builds new containers for
        list values, so mutating the returned dict does not affect this
        instance.
        """
        return dataclasses.asdict(self)


# Preset table: maps each preset name to its ready-made ViT configuration.
VIT_CONFIG_DICT: dict[ViTPreset, ViTConfig] = {
    "dinov2l16_384": ViTConfig(
        in_chans=3,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        init_values=1e-5,
        # Overrides the "avg" default with an empty string.
        # NOTE(review): presumably this disables pooling -- confirm against
        # the code that consumes ``global_pool``.
        global_pool="",
    ),
}