Spaces:
Running on Zero
Running on Zero
Commit ·
78bb21c
0
Parent(s):
init
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff
- InfiniDepth/__init__.py +1 -0
- InfiniDepth/model/__init__.py +9 -0
- InfiniDepth/model/block/__init__.py +1 -0
- InfiniDepth/model/block/common.py +43 -0
- InfiniDepth/model/block/config.py +5 -0
- InfiniDepth/model/block/convolution.py +229 -0
- InfiniDepth/model/block/implicit_decoder.py +179 -0
- InfiniDepth/model/block/pe.py +222 -0
- InfiniDepth/model/block/perceive_io.py +274 -0
- InfiniDepth/model/block/prompt_models/__init__.py +31 -0
- InfiniDepth/model/block/prompt_models/__pycache__/__init__.cpython-310.pyc +0 -0
- InfiniDepth/model/block/prompt_models/__pycache__/__init__.cpython-311.pyc +0 -0
- InfiniDepth/model/block/prompt_models/__pycache__/crossattn.cpython-310.pyc +0 -0
- InfiniDepth/model/block/prompt_models/__pycache__/diffattn.cpython-310.pyc +0 -0
- InfiniDepth/model/block/prompt_models/__pycache__/rope.cpython-310.pyc +0 -0
- InfiniDepth/model/block/prompt_models/__pycache__/rope.cpython-311.pyc +0 -0
- InfiniDepth/model/block/prompt_models/__pycache__/sam.cpython-310.pyc +0 -0
- InfiniDepth/model/block/prompt_models/__pycache__/sam.cpython-311.pyc +0 -0
- InfiniDepth/model/block/prompt_models/__pycache__/selfattn.cpython-310.pyc +0 -0
- InfiniDepth/model/block/prompt_models/__pycache__/selfattn.cpython-311.pyc +0 -0
- InfiniDepth/model/block/prompt_models/crossattn.py +164 -0
- InfiniDepth/model/block/prompt_models/rope.py +215 -0
- InfiniDepth/model/block/prompt_models/sam.py +126 -0
- InfiniDepth/model/block/prompt_models/selfattn.py +289 -0
- InfiniDepth/model/block/prompt_models/utils/__init__.py +1 -0
- InfiniDepth/model/block/prompt_models/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- InfiniDepth/model/block/prompt_models/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- InfiniDepth/model/block/prompt_models/utils/__pycache__/pe_utils.cpython-310.pyc +0 -0
- InfiniDepth/model/block/prompt_models/utils/__pycache__/pe_utils.cpython-311.pyc +0 -0
- InfiniDepth/model/block/prompt_models/utils/__pycache__/transformer.cpython-310.pyc +0 -0
- InfiniDepth/model/block/prompt_models/utils/__pycache__/transformer.cpython-311.pyc +0 -0
- InfiniDepth/model/block/prompt_models/utils/pe_utils.py +72 -0
- InfiniDepth/model/block/prompt_models/utils/transformer.py +250 -0
- InfiniDepth/model/block/rope.py +69 -0
- InfiniDepth/model/block/torchhub/README.md +3 -0
- InfiniDepth/model/block/torchhub/dinov3/.docstr.yaml +6 -0
- InfiniDepth/model/block/torchhub/dinov3/.github/workflows/lint.yaml +47 -0
- InfiniDepth/model/block/torchhub/dinov3/.gitignore +18 -0
- InfiniDepth/model/block/torchhub/dinov3/CODE_OF_CONDUCT.md +80 -0
- InfiniDepth/model/block/torchhub/dinov3/CONTRIBUTING.md +31 -0
- InfiniDepth/model/block/torchhub/dinov3/LICENSE.md +66 -0
- InfiniDepth/model/block/torchhub/dinov3/MODEL_CARD.md +432 -0
- InfiniDepth/model/block/torchhub/dinov3/README.md +734 -0
- InfiniDepth/model/block/torchhub/dinov3/conda.yaml +23 -0
- InfiniDepth/model/block/torchhub/dinov3/dinov3/__init__.py +6 -0
- InfiniDepth/model/block/torchhub/dinov3/dinov3/checkpointer/__init__.py +18 -0
- InfiniDepth/model/block/torchhub/dinov3/dinov3/checkpointer/checkpointer.py +349 -0
- InfiniDepth/model/block/torchhub/dinov3/dinov3/configs/__init__.py +16 -0
- InfiniDepth/model/block/torchhub/dinov3/dinov3/configs/config.py +222 -0
- InfiniDepth/model/block/torchhub/dinov3/dinov3/configs/ssl_default_config.yaml +205 -0
InfiniDepth/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""InfiniDepth package."""
|
InfiniDepth/model/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .registry import MODEL_REGISTRY, register_model
|
| 2 |
+
from .model import InfiniDepth, InfiniDepth_DC
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"MODEL_REGISTRY",
|
| 6 |
+
"register_model",
|
| 7 |
+
"InfiniDepth",
|
| 8 |
+
"InfiniDepth_DC",
|
| 9 |
+
]
|
InfiniDepth/model/block/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Core building blocks for InfiniDepth models."""
|
InfiniDepth/model/block/common.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from typing import Type
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MLPBlock(nn.Module):
    """Standard transformer feed-forward block: Linear -> activation -> Linear."""

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        # Expand to mlp_dim, apply the activation, project back to embedding_dim.
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.act(self.lin1(x))
        return self.lin2(hidden)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    """LayerNorm for NCHW tensors: normalizes over the channel dim (dim 1)."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        # Per-channel affine rescaling, broadcast over H and W.
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
|
InfiniDepth/model/block/config.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Per-backbone configuration table keyed by DINOv3 variant name.
# NOTE(review): field semantics inferred from names — confirm against the
# model builder that consumes this dict: 'encoder' = backbone size tag,
# 'features' = decoder feature width, 'out_channels' = channel widths of the
# four tapped stages, 'layer_idxs' = transformer block indices to tap.
dinov3_model_configs = {
    "vitl16":{'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024], 'layer_idxs': [4, 11, 17, 23]},
    "vith16plus": {'encoder': 'vith', 'features': 384, 'out_channels': [1280, 1280, 1280, 1280], 'layer_idxs': [7, 15, 23, 31]},
}
|
| 5 |
+
|
InfiniDepth/model/block/convolution.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from functools import partial
|
| 11 |
+
from typing import Callable
|
| 12 |
+
import collections
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
from itertools import repeat
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
    r"""Sample ``input`` at ``coords`` using bilinear interpolation.

    Wrapper around :func:`torch.nn.functional.grid_sample` that takes
    unnormalized pixel coordinates instead of grid_sample's [-1, 1]
    convention.

    ``input`` is :math:`(B, C, H, W)` and ``coords`` is
    :math:`(B, H_o, W_o, 2)` holding points :math:`(x_i, y_i)`.
    Alternatively ``input`` may be :math:`(B, C, T, H, W)`, in which case the
    sample points are triplets :math:`(t_i, x_i, y_i)` — note the ordering
    differs from grid_sample's own :math:`(x_i, y_i, t_i)`; this function
    reorders internally.

    With ``align_corners=True`` coordinate :math:`x` spans :math:`[0, W-1]`
    (pixel centers at the endpoints); with ``align_corners=False`` it spans
    :math:`[0, W]` (pixel edges). The same conventions apply to :math:`y`
    (via :math:`H`) and :math:`t` (via :math:`T`).

    Args:
        input (Tensor): batch of images (4D) or volumes (5D).
        coords (Tensor): batch of sample coordinates.
        align_corners (bool, optional): Coordinate convention. Defaults to `True`.
        padding_mode (str, optional): Padding mode. Defaults to `"border"`.

    Returns:
        Tensor: sampled points.
    """
    spatial_sizes = input.shape[2:]
    assert len(spatial_sizes) in [2, 3]

    if len(spatial_sizes) == 3:
        # t x y -> x y t to match dimensions T H W in grid_sample
        coords = coords[..., [1, 2, 0]]

    # Rescale pixel coordinates into grid_sample's [-1, 1] range.
    if align_corners:
        factors = [2 / max(size - 1, 1) for size in reversed(spatial_sizes)]
    else:
        factors = [2 / size for size in reversed(spatial_sizes)]
    coords = coords * torch.tensor(factors, device=coords.device)
    coords -= 1

    return F.grid_sample(
        input, coords, align_corners=align_corners, padding_mode=padding_mode
    )
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def round_to_multiple_of_4(n):
    """Round ``n`` to the nearest multiple of 4 (Python banker's rounding on ties)."""
    quarter = n / 4
    return 4 * round(quarter)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a residual connection and configurable norm.

    Args:
        in_planes: input channel count.
        planes: output channel count.
        norm_fn: normalization type: "group", "batch", "instance", or "none".
        stride: stride of the first conv; when != 1, a 1x1 projection
            shortcut (with its own norm, ``norm3``) matches the residual shape.

    Raises:
        ValueError: if ``norm_fn`` is not one of the supported names.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
        )
        self.relu = nn.ReLU(inplace=True)

        # NOTE(review): GroupNorm needs num_groups >= 1, so "group" assumes
        # planes >= 8 — confirm with callers.
        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        else:
            # Fix: previously an unknown norm_fn silently left norm1/norm2
            # undefined, surfacing later as a confusing AttributeError.
            raise ValueError(f"Unknown norm_fn: {norm_fn!r}")

        if stride == 1:
            self.downsample = None
        else:
            # Projection shortcut reuses norm3 so both paths are normalized.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class BasicEncoder(nn.Module):
    """Convolutional pyramid encoder producing one fused feature map.

    A stride-2 stem is followed by four residual stages (resolutions 1/2,
    1/2, 1/4, 1/8, 1/16 of the input). Each stage output is bilinearly
    resized to (H/stride, W/stride), concatenated, and fused down to
    ``output_dim`` channels.
    """

    def __init__(self, input_dim=3, output_dim=128, stride=4):
        super(BasicEncoder, self).__init__()
        self.stride = stride
        self.norm_fn = "instance"  # norm type forwarded to every ResidualBlock
        self.in_planes = output_dim // 2  # running width, mutated by _make_layer
        self.norm1 = nn.InstanceNorm2d(self.in_planes)
        self.norm2 = nn.InstanceNorm2d(output_dim * 2)

        # Stem: 7x7 stride-2 conv -> half resolution.
        self.conv1 = nn.Conv2d(
            input_dim,
            self.in_planes,
            kernel_size=7,
            stride=2,
            padding=3,
            padding_mode="zeros",
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(output_dim // 2, stride=1)
        self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
        self.layer3 = self._make_layer(output_dim, stride=2)
        self.layer4 = self._make_layer(output_dim, stride=2)

        # Fusion conv input = sum of the four stage widths:
        # output_dim//2 + output_dim//4*3 + output_dim + output_dim
        # == output_dim*3 + output_dim//4 (assumes output_dim % 4 == 0 —
        # TODO confirm with callers).
        self.conv2 = nn.Conv2d(
            output_dim * 3 + output_dim // 4,
            output_dim * 2,
            kernel_size=3,
            padding=1,
            padding_mode="zeros",
        )
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
        # Kaiming init for convs; affine InstanceNorms (if any) start at identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.InstanceNorm2d)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Build one stage of two ResidualBlocks; only the first may downsample.
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim  # next stage consumes this stage's width
        return nn.Sequential(*layers)

    def forward(self, x):
        # x: (B, input_dim, H, W) -> (B, output_dim, H//stride, W//stride)
        _, _, H, W = x.shape

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        a = self.layer1(x)
        b = self.layer2(a)
        c = self.layer3(b)
        d = self.layer4(c)

        # Resize every stage onto the common (H/stride, W/stride) grid.
        def _bilinear_intepolate(x):
            return F.interpolate(
                x,
                (H // self.stride, W // self.stride),
                mode="bilinear",
                align_corners=True,
            )

        a = _bilinear_intepolate(a)
        b = _bilinear_intepolate(b)
        c = _bilinear_intepolate(c)
        d = _bilinear_intepolate(d)

        x = self.conv2(torch.cat([a, b, c, d], dim=1))
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        return x
|
InfiniDepth/model/block/implicit_decoder.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import enum
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import sys
|
| 7 |
+
from grpc import insecure_channel
|
| 8 |
+
from sympy import use
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def exists(val):
    """Return True when *val* is not None."""
    return val is not None


def default(val, d):
    """Return *val* when it is not None, otherwise the fallback *d*."""
    if exists(val):
        return val
    return d
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MLP(nn.Module):
    """Plain fully-connected network: a stack of Linear+ReLU hidden layers.

    When ``out_dim`` is not None, a final Linear layer is appended, followed
    by the activation named by ``output_act`` ("sigmoid" | "relu" | "elu";
    any other value maps to identity).
    """

    def __init__(self, in_dim, out_dim, hidden_list, output_act='elu'):
        super().__init__()
        modules = []
        width = in_dim
        for hidden_width in hidden_list:
            modules.append(nn.Linear(width, hidden_width))
            modules.append(nn.ReLU())
            width = hidden_width

        if out_dim is not None:
            modules.append(nn.Linear(width, out_dim))
            output_activations = {
                "sigmoid": nn.Sigmoid(),
                "relu": nn.ReLU(),
                "elu": nn.ELU(),
            }
            modules.append(output_activations.get(output_act, nn.Identity()))

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        return self.layers(x)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ImplicitHead(nn.Module):
    """
    Implicit head that fuses DINOv3 semantic features and BasicEncoder low-level features.

    Args:
        hidden_dim: DINOv3 feature dimension (e.g., 1024)
        basic_dim: BasicEncoder feature dimension (e.g., 128)
        fusion_type: Feature fusion strategy
            - "concat": Simple concatenation
            - "gated": Gated fusion with learnable weights
        out_dim: Output dimension (1 for depth)
        hidden_list: MLP hidden layer dimensions; None means [1024, 256, 32]

    Raises:
        ValueError: if ``fusion_type`` is not "concat" or "gated".
    """
    def __init__(
        self,
        hidden_dim,            # e.g. 1024 for DINOv3
        basic_dim=128,         # BasicEncoder output dim
        fusion_type="gated",   # "concat" or "gated"
        out_dim=1,
        hidden_list=None,      # fix: avoid mutable default argument
    ):
        super().__init__()
        if hidden_list is None:
            hidden_list = [1024, 256, 32]
        self.hidden_dim = hidden_dim
        self.basic_dim = basic_dim
        self.fusion_type = fusion_type

        # Determine MLP input dimension based on fusion type
        if fusion_type == "concat":
            # Simple concatenation of the two feature vectors
            in_channels = hidden_dim + basic_dim
        elif fusion_type == "gated":
            # Gated fusion: project basic features to hidden_dim, then blend
            # with a learned sigmoid gate.
            self.gate_proj = nn.Linear(basic_dim, hidden_dim)
            self.gate = nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.Sigmoid()
            )
            in_channels = hidden_dim
        else:
            raise ValueError(f"Unknown fusion_type: {fusion_type}")

        self.out_layer = MLP(
            in_dim=in_channels,
            out_dim=out_dim,
            hidden_list=hidden_list,
            output_act='elu'
        )

    def _encode_feat(self, features, patch_h, patch_w):
        """Reshape the last backbone feature from [B, N, C] to [B, C, patch_h, patch_w]."""
        x = features[-1][0]
        out_feat = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
        return out_feat

    def _decode_dpt(self, feat, basic_feat, coord):
        """
        Query features at given coordinates and fuse them.

        Args:
            feat: DINOv3 feature map [B, hidden_dim, H_dino, W_dino]
            basic_feat: BasicEncoder feature map [B, basic_dim, H_basic, W_basic] or None
            coord: Query coordinates [B, N, 2] in range [-1, 1]

        Returns:
            pred: Predicted depth [B, N, 1]
        """
        coord_ = coord.clone()
        # Keep samples strictly inside the valid grid range.
        coord_.clamp_(-1 + 1e-6, 1 - 1e-6)

        # Sample DINOv3 features at query coordinates.
        # flip(-1): coords are stored (y, x)-last? grid_sample wants (x, y) last —
        # NOTE(review): confirm coordinate ordering against the caller.
        q_feat_dino = F.grid_sample(
            feat, coord_.flip(-1).unsqueeze(1),
            mode='bilinear', align_corners=False
        )[:, :, 0, :].permute(0, 2, 1)  # [B, N, hidden_dim]

        # Sample BasicEncoder features at query coordinates (if available)
        if basic_feat is not None:
            q_feat_basic = F.grid_sample(
                basic_feat, coord_.flip(-1).unsqueeze(1),
                mode='bilinear', align_corners=False
            )[:, :, 0, :].permute(0, 2, 1)  # [B, N, basic_dim]

            q_feat_fused = self._fuse_features(q_feat_dino, q_feat_basic)
        else:
            # If no basic features, use only DINOv3.
            # NOTE(review): only valid for fusion_type="gated" — with "concat"
            # the MLP expects hidden_dim + basic_dim inputs.
            q_feat_fused = q_feat_dino

        # Predict depth
        pred = self.out_layer(q_feat_fused)
        return pred

    def _fuse_features(self, feat_dino, feat_basic):
        """
        Fuse DINOv3 and BasicEncoder features.

        Args:
            feat_dino: [B, N, hidden_dim]
            feat_basic: [B, N, basic_dim]

        Returns:
            fused_feat: [B, N, fused_dim]
        """
        if self.fusion_type == "concat":
            return torch.cat([feat_dino, feat_basic], dim=-1)

        elif self.fusion_type == "gated":
            feat_basic_proj = self.gate_proj(feat_basic)  # [B, N, hidden_dim]
            gate_input = torch.cat([feat_dino, feat_basic_proj], dim=-1)
            gate_weights = self.gate(gate_input)  # [B, N, hidden_dim]
            return gate_weights * feat_dino + (1 - gate_weights) * feat_basic_proj

        # Fix: previously fell through and returned None silently.
        raise ValueError(f"Unknown fusion_type: {self.fusion_type}")

    def forward(self, features, basic_feat, patch_h, patch_w, coords):
        """
        Forward pass.

        Args:
            features: DINOv3 features from backbone
            basic_feat: BasicEncoder features [B, basic_dim, H/4, W/4] or None
            patch_h, patch_w: DINOv3 feature map spatial size
            coords: Query coordinates [B, N, 2]

        Returns:
            dpt_pred: Predicted depth [B, N, 1]
        """
        # Extract DINOv3 feature map
        feat = self._encode_feat(features, patch_h, patch_w)  # [B, hidden_dim, patch_h, patch_w]

        # Query and fuse features at coordinates
        dpt_pred = self._decode_dpt(feat, basic_feat, coords)

        return dpt_pred
|
InfiniDepth/model/block/pe.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from typing import Any, Optional, Tuple, Dict
|
| 9 |
+
|
| 10 |
+
# Reduced-precision accumulation dtype selected once at import time:
# bfloat16 on GPUs with compute capability >= 8 (hardware bf16 support),
# float16 otherwise. On CPU-only machines this falls back to float16 —
# NOTE(review): presumably only CUDA code paths consume acc_dtype; confirm.
if torch.cuda.is_available():
    acc_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
    acc_dtype = torch.float16
|
| 14 |
+
|
| 15 |
+
# Global registry mapping lowercase names to positional-embedding classes.
POS_EMB_REGISTRY = {}


def register_pos_emb(name):
    """Class decorator that registers a positional-embedding class under a lowercase key."""
    key = name.lower()

    def decorator(cls):
        POS_EMB_REGISTRY[key] = cls
        return cls

    return decorator
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@register_pos_emb("dct")
|
| 25 |
+
class DctPositionEmbedding(nn.Module):
    """
    2D separable DCT encoding for query coordinates coords: [B, N, 2]:
        Phi(x,y)[fx,fy] = cos(pi * fx * x) * cos(pi * fy * y) * 1/(1+fx*fy)
    Convention: coords should be in the range [0, 1].
    """
    def __init__(self, max_freqs: int = 8):
        super().__init__()
        self.max_freqs = max_freqs

        freqs = torch.arange(max_freqs).float()  # [F] -> 0..F-1
        fx = freqs.view(-1, 1)  # [F,1]
        fy = freqs.view(1, -1)  # [1,F]
        # Attenuate high joint frequencies: 1/(1 + fx*fy).
        coeffs = (1.0 + fx * fy) ** -1  # [F,F]

        self.register_buffer("_freqs_1d", freqs, persistent=False)
        self.register_buffer("_coeffs_2d", coeffs, persistent=False)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """
        coords: [B, N, 2], value range should be [0,1]
        return: [B, N, max_freqs**2]
        """
        assert coords.dim() == 3 and coords.size(-1) == 2, "coords must be [B, N, 2]"
        B, N, _ = coords.shape
        device, dtype = coords.device, coords.dtype

        freqs = self._freqs_1d.to(device=device, dtype=dtype)  # [F]
        coeffs = self._coeffs_2d.to(device=device, dtype=dtype)  # [F,F]
        # Fix: local was named `F`, shadowing torch.nn.functional (imported as F).
        n_freqs = freqs.numel()  # frequency dimension = max_freqs

        x = coords[..., 0:1]  # [B,N,1]
        y = coords[..., 1:2]  # [B,N,1]
        dct_x = torch.cos(math.pi * x * freqs.view(1, 1, n_freqs))  # [B,N,F]
        dct_y = torch.cos(math.pi * y * freqs.view(1, 1, n_freqs))  # [B,N,F]

        # Outer product over the two frequency axes, then attenuate and flatten.
        out = dct_x.unsqueeze(-1) * dct_y.unsqueeze(-2)  # [B,N,F,F]
        out = out * coeffs.view(1, 1, n_freqs, n_freqs)  # [B,N,F,F]
        dct_emb = out.reshape(B, N, n_freqs * n_freqs)  # [B,N,F^2]

        return dct_emb  # [B,N,F^2]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@register_pos_emb("random")
|
| 69 |
+
class RandomPositionEmbedding(nn.Module):
    """Positional encoding built from random spatial (Fourier) frequencies."""

    patch_size = 14

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None, image_pe_method: str = "patch") -> None:
        super().__init__()
        # Non-positive or missing scale falls back to 1.0.
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

        self.image_pe_method = image_pe_method
        if self.image_pe_method == "image":
            # self.patch_embed = nn.Conv2d(num_pos_feats*2, num_pos_feats*2, kernel_size=self.patch_size, stride=self.patch_size)
            # Downsample a full-resolution encoding back to the patch grid.
            self.patch_embed = nn.Sequential(
                nn.Conv2d(num_pos_feats * 2, num_pos_feats // 2, kernel_size=2, stride=2),
                nn.ReLU(),
                nn.Conv2d(
                    num_pos_feats // 2,
                    num_pos_feats * 2,
                    kernel_size=self.patch_size // 2,
                    stride=self.patch_size // 2,
                ),
            )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points normalized to [0,1] (shape d_1 x ... x d_n x 2)."""
        # [0,1] -> [-1,1]; equivalent to align_corners=False after this transform.
        projected = (2 * coords - 1) @ self.positional_encoding_gaussian_matrix
        angles = 2 * np.pi * projected
        # Output shape d_1 x ... x d_n x (2 * num_pos_feats).
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

    def forward_encoding(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate the positional encoding for an (h, w) grid of pixel centers."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        ones = torch.ones((h, w), device=device, dtype=torch.float32)
        ys = (ones.cumsum(dim=0) - 0.5) / h
        xs = (ones.cumsum(dim=1) - 0.5) / w
        grid_pe = self._pe_encoding(torch.stack([xs, ys], dim=-1))  # HxWx2 -> HxWxC
        return grid_pe.permute(2, 0, 1)  # C x H x W

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        if self.image_pe_method == "patch":
            return self.forward_encoding(size)
        elif self.image_pe_method == "image":
            coarse = self.forward_encoding(size)
            fine = self.forward_encoding((size[0] * self.patch_size, size[1] * self.patch_size))
            return coarse + self.patch_embed(fine[None])[0]

    def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
        """Positionally encode unnormalized points given the (H, W) image size."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@register_pos_emb("rope")
class RoPEPositionEmbedding(nn.Module):
    """2D Rotary Position Embedding with support for continuous coordinates.

    For each coordinate p (can be float), directly compute θ = p * inv_freq,
    then derive cos/sin.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        # Cache of inv_freq vectors. Keyed by (dim, device, dtype) so a cached
        # tensor is never served on the wrong device or with the wrong dtype.
        self._inv_freq_cache: Dict[Tuple[int, torch.device, torch.dtype], torch.Tensor] = {}

    def _get_inv_freq(self, dim: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return the inverse-frequency vector of length dim/2: 1 / base_freq^(2i/d).

        Bug fix: the cache was previously keyed by ``dim`` alone, so after the
        first call the cached tensor could live on a different device (or hold
        a different dtype) than requested, breaking multi-device/AMP usage.
        """
        key = (dim, device, dtype)
        if key not in self._inv_freq_cache:
            # Frequencies are defined on even dimensions only.
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency ** exponents)
            self._inv_freq_cache[key] = inv_freq.to(dtype)
        return self._inv_freq_cache[key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Rotation: split the last dim into halves (u, v) and return (-v, u)."""
        D = x.shape[-1]
        x1, x2 = x[..., : D // 2], x[..., D // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope_continuous(
        self,
        x: torch.Tensor,         # [B, N, d_half]
        pos: torch.Tensor,       # [B, N] floating-point coordinates
        inv_freq: torch.Tensor,  # [d_half / 2]
    ) -> torch.Tensor:
        """Apply 1D RoPE with continuous positions along one axis."""
        # 1) Angles: outer product of positions and inverse frequencies.
        angles = pos.unsqueeze(-1) * inv_freq.unsqueeze(0)
        # 2) Duplicate so cos/sin cover the full feature width of x.
        angles = torch.cat([angles, angles], dim=-1)
        # 3) cos/sin terms, same shape as x.
        cos = angles.cos()
        sin = angles.sin()
        # 4) Apply the rotation.
        return x * cos + self._rotate_features(x) * sin

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """
        tokens: [B, N, dim]
        positions: [B, N, 2] continuous coords: (y, x)
        """
        B, N, D = tokens.shape
        assert D % 2 == 0, "Feature dimension must be even"
        assert positions.shape == (B, N, 2), "positions must be [B, N, 2]"

        # Allocate half of the features to each spatial direction.
        d_half = D // 2

        inv_freq = self._get_inv_freq(d_half, tokens.device, tokens.dtype)  # [d_half/2]
        # Split the feature dimension into vertical and horizontal halves.
        tok_v, tok_h = tokens[..., :d_half], tokens[..., d_half:]

        # positions[..., 0] is y (vertical); positions[..., 1] is x (horizontal).
        out_v = self._apply_1d_rope_continuous(tok_v, positions[..., 0], inv_freq)
        out_h = self._apply_1d_rope_continuous(tok_h, positions[..., 1], inv_freq)

        return torch.cat([out_v, out_h], dim=-1)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def build_pos_emb(pos_emb_type="nerf", **kwargs):
    """Look up ``pos_emb_type`` (case-insensitive) in the registry and build it."""
    key = pos_emb_type.lower()
    if key in POS_EMB_REGISTRY:
        return POS_EMB_REGISTRY[key](**kwargs)
    raise ValueError(f"Unknown pos_emb_type: {key}")
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
|
InfiniDepth/model/block/perceive_io.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Optional, Union
|
| 2 |
+
import torch
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
from ...utils.logger import Log
|
| 5 |
+
from torch import Tensor, nn
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
from xformers.ops import memory_efficient_attention, unbind
|
| 9 |
+
|
| 10 |
+
XFORMERS_AVAILABLE = True
|
| 11 |
+
except ImportError:
|
| 12 |
+
Log.warning("xFormers not available")
|
| 13 |
+
XFORMERS_AVAILABLE = False
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CrossAttention(nn.Module):
    """Vanilla (non-memory-efficient) cross-attention.

    Queries come from ``x`` (dim ``dim``); keys and values are projected from
    ``context`` (dim ``context_dim``). NOTE(review): the forward pass splits
    context heads with ``C // num_heads`` where C is x's dim, so it assumes
    ``context_dim == dim`` — confirm against callers.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        pe: str = "normal",
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        # Query projection for x; fused key/value projection for context.
        self.qkv = nn.Linear(dim, dim, bias=qkv_bias)
        self.qkv_context = nn.Linear(context_dim, context_dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.pe = pe
        if self.pe == "qk":
            # Per-head LayerNorms used by the mem-efficient subclass after
            # adding positional encodings to q and k.
            self.norm1 = nn.LayerNorm(dim // num_heads)
            self.norm2 = nn.LayerNorm(dim // num_heads)

    def forward(self, x: Tensor, context: Tensor) -> Tensor:
        """x: [B, N, C] queries; context: [B, M, C] keys/values."""
        B, N, C = x.shape
        _, M, _ = context.shape

        qkv_x = self.qkv(x).reshape(B, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q_x = qkv_x[0] * self.scale

        qkv_context = (
            self.qkv_context(context).reshape(B, M, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        )
        # Bug fix: values previously aliased the keys (qkv_context[0] used
        # twice), so half of qkv_context's projection was never used. Split
        # k and v properly, matching MemEffCrossAttention's unbind(..., 2).
        k_context, v_context = qkv_context[0], qkv_context[1]

        # Cross-attention: query from x and key/value from context
        attn = q_x @ k_context.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v_context).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MemEffCrossAttention(CrossAttention):
    """CrossAttention that runs through xFormers' memory_efficient_attention.

    Falls back to the parent implementation when xFormers is unavailable.
    NOTE(review): the fallback call ignores x_pe/context_pe, because
    CrossAttention.forward does not accept them — confirm that is acceptable
    for the "qk" PE mode.
    """

    def forward(
        self, x: Tensor, context: Tensor, x_pe: Tensor = None, context_pe: Tensor = None, attn_bias=None
    ) -> Tensor:
        """x: [B, N, C] queries; context: [B, M, C_context] keys/values."""
        if not XFORMERS_AVAILABLE:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            return super().forward(x, context)

        B, N, C = x.shape
        _, M, C_context = context.shape

        # Keep the [B, N, heads, head_dim] layout that xFormers expects.
        qkv_x = self.qkv(x).reshape(B, N, 1, self.num_heads, C // self.num_heads)
        (q_x,) = unbind(qkv_x, 2)

        qkv_context = self.qkv_context(context).reshape(B, M, 2, self.num_heads, C_context // self.num_heads)
        k_context, v_context = unbind(qkv_context, 2)

        if self.pe == "qk":
            # PEs are stored flattened across heads; unflatten, add, normalize.
            q_x = self.norm1(q_x + rearrange(x_pe, "b n (m c) -> b n m c", m=self.num_heads))
            k_context = self.norm2(k_context + rearrange(context_pe, "b n (m c) -> b n m c", m=self.num_heads))
        elif self.pe == "apply":
            pass
        # Memory-efficient cross-attention
        # Cast q/k to v's dtype: norm1/norm2 can change dtype under autocast,
        # and memory_efficient_attention requires matching dtypes.
        x = memory_efficient_attention(
            q_x.to(dtype=v_context.dtype), k_context.to(dtype=v_context.dtype), v_context, attn_bias=attn_bias
        )
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# class Attention(nn.Module):
|
| 101 |
+
# def __init__(
|
| 102 |
+
# self,
|
| 103 |
+
# dim: int,
|
| 104 |
+
# num_heads: int = 8,
|
| 105 |
+
# qkv_bias: bool = False,
|
| 106 |
+
# proj_bias: bool = True,
|
| 107 |
+
# attn_drop: float = 0.0,
|
| 108 |
+
# proj_drop: float = 0.0,
|
| 109 |
+
# ) -> None:
|
| 110 |
+
# super().__init__()
|
| 111 |
+
# self.num_heads = num_heads
|
| 112 |
+
# head_dim = dim // num_heads
|
| 113 |
+
# self.scale = head_dim**-0.5
|
| 114 |
+
|
| 115 |
+
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 116 |
+
# self.attn_drop = nn.Dropout(attn_drop)
|
| 117 |
+
# self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 118 |
+
# self.proj_drop = nn.Dropout(proj_drop)
|
| 119 |
+
|
| 120 |
+
# def forward(self, x: Tensor) -> Tensor:
|
| 121 |
+
# B, N, C = x.shape
|
| 122 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 123 |
+
|
| 124 |
+
# q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
| 125 |
+
# attn = q @ k.transpose(-2, -1)
|
| 126 |
+
|
| 127 |
+
# attn = attn.softmax(dim=-1)
|
| 128 |
+
# attn = self.attn_drop(attn)
|
| 129 |
+
|
| 130 |
+
# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 131 |
+
# x = self.proj(x)
|
| 132 |
+
# x = self.proj_drop(x)
|
| 133 |
+
# return x
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class Attention(nn.Module):
    """Multi-head self-attention with optional additive "qk" positional encoding."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        pe: str = "normal",
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        # Fused query/key/value projection and output projection.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.pe = pe
        if self.pe == "qk":
            # Per-head norms applied after adding the PE to q and k.
            self.norm1 = nn.LayerNorm(dim // num_heads)
            self.norm2 = nn.LayerNorm(dim // num_heads)

    def forward(self, x: Tensor, x_pe: Tensor = None) -> Tensor:
        """x: [B, N, C] tokens; x_pe: optional PE, used when pe == "qk"."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        projected = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)

        query, key, value = projected[0] * self.scale, projected[1], projected[2]
        if self.pe == "qk":
            query = self.norm1(query + x_pe)
            key = self.norm2(key + x_pe)
        elif self.pe == "apply":
            pass  # PE is applied by the caller in this mode

        weights = (query @ key.transpose(-2, -1)).softmax(dim=-1)
        weights = self.attn_drop(weights)

        out = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj(out)
        return self.proj_drop(out)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class MemEffAttention(Attention):
    """Attention that uses xFormers' memory_efficient_attention when available."""

    def forward(
        self,
        x: Tensor,
        x_pe: Tensor = None,
        attn_bias=None,
    ) -> Tensor:
        """x: [B, N, C] tokens; x_pe: optional PE, required when self.pe == "qk"."""
        if not XFORMERS_AVAILABLE:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            # Bug fix: forward x_pe to the fallback as well; previously it was
            # dropped, so pe == "qk" broke (q + None) whenever xFormers was
            # missing. Behavior is unchanged when x_pe is None.
            return super().forward(x, x_pe)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = unbind(qkv, 2)
        if self.pe == "qk":
            # PE is stored flattened across heads; unflatten, add, normalize.
            q = self.norm1(q + rearrange(x_pe, "b n (m c) -> b n m c", m=self.num_heads))
            k = self.norm2(k + rearrange(x_pe, "b n (m c) -> b n m c", m=self.num_heads))
        elif self.pe == "apply":
            pass
        # this is important
        # as q, k after norm1/norm2 have different dtype
        # which will cause error in memory_efficient_attention
        x = memory_efficient_attention(q.to(dtype=v.dtype), k.to(dtype=v.dtype), v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Hidden and output widths default to the input width.
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class LayerScale(nn.Module):
    """Scale features by a learnable per-channel gain (optionally in place)."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # One learnable gain per channel, initialized to init_values.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples (stochastic depth), rescaling survivors.

    A no-op when ``drop_prob`` is 0 or when not training.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        # Rescale so the expected activation magnitude is unchanged.
        mask.div_(keep_prob)
    return x * mask
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class DropPath(nn.Module):
    """Per-sample stochastic depth: randomly skips this residual branch."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Module flavour of the functional drop_path; active only in training.
        return drop_path(x, self.drop_prob, self.training)
|
InfiniDepth/model/block/prompt_models/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from .sam import SAMPromptModel
|
| 3 |
+
from .selfattn import SelfAttnPromptModel
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"GeneralPromptModel",
|
| 7 |
+
"SelfAttnPromptModel",
|
| 8 |
+
"SAMPromptModel",
|
| 9 |
+
]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class GeneralPromptModel(nn.Module):
    """Applies a prompt-fusion block to selected stages of a feature pyramid."""

    def __init__(self, prompt_stage=[3], **kwargs):
        # NOTE(review): mutable default argument; safe only if callers never
        # mutate prompt_stage — consider None + default inside.
        super().__init__()
        self.prompt_stage = prompt_stage
        # Map absolute stage index -> position in self.prompt_model.
        self.prompt_idmap = {i: idx for idx, i in enumerate(self.prompt_stage)}
        block = kwargs.get("block")
        # NOTE(review): the SAME block instance is repeated for every stage,
        # so all stages share one set of weights — confirm this is intended
        # (copy.deepcopy(block) would give independent weights per stage).
        self.prompt_model = nn.ModuleList([block for _ in range(len(self.prompt_stage))])

    def forward(self, features, prompt_depth, prompt_mask, patch_h, patch_w):
        """Run the prompt block on each configured stage.

        features: list where features[i][0] holds the stage-i token tensor;
        updated in place and the list is returned.
        """
        for i in range(len(features)):
            if i not in self.prompt_stage:  # prompt_stage = [3]
                continue
            features[i][0] = self.prompt_model[self.prompt_idmap[i]](
                features[i][0],
                prompt_depth,
                prompt_mask,
                patch_h,
                patch_w,
            )
        return features
|
InfiniDepth/model/block/prompt_models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.49 kB). View file
|
|
|
InfiniDepth/model/block/prompt_models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (2.44 kB). View file
|
|
|
InfiniDepth/model/block/prompt_models/__pycache__/crossattn.cpython-310.pyc
ADDED
|
Binary file (5.79 kB). View file
|
|
|
InfiniDepth/model/block/prompt_models/__pycache__/diffattn.cpython-310.pyc
ADDED
|
Binary file (142 Bytes). View file
|
|
|
InfiniDepth/model/block/prompt_models/__pycache__/rope.cpython-310.pyc
ADDED
|
Binary file (8.28 kB). View file
|
|
|
InfiniDepth/model/block/prompt_models/__pycache__/rope.cpython-311.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
InfiniDepth/model/block/prompt_models/__pycache__/sam.cpython-310.pyc
ADDED
|
Binary file (3.51 kB). View file
|
|
|
InfiniDepth/model/block/prompt_models/__pycache__/sam.cpython-311.pyc
ADDED
|
Binary file (5.48 kB). View file
|
|
|
InfiniDepth/model/block/prompt_models/__pycache__/selfattn.cpython-310.pyc
ADDED
|
Binary file (8.51 kB). View file
|
|
|
InfiniDepth/model/block/prompt_models/__pycache__/selfattn.cpython-311.pyc
ADDED
|
Binary file (15.9 kB). View file
|
|
|
InfiniDepth/model/block/prompt_models/crossattn.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
from ..perceive_io import LayerScale, MemEffCrossAttention, Mlp
|
| 5 |
+
from .utils.pe_utils import PositionEmbeddingRandom
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CrossAttnPromptModel(nn.Module):
    """Fuses sparse depth prompts into image embeddings via cross-attention.

    For each batch item with at least one valid prompt point, the sparse
    depth values are lifted to token embeddings and exchanged with the image
    tokens through a stack of bidirectional cross-attention blocks.
    """

    def __init__(
        self,
        transformer_dim: int = 1024,
        num_blocks: int = 1,
        num_heads: int = 4,
        pe: str = "normal",
        image_pe_method: str = "patch",  # or "image"
        **kwargs,
    ) -> None:
        """
        Arguments:
            transformer_dim (int): channel dimension of the image/prompt tokens
            num_blocks (int): number of CrossAttenPromptBlock layers
            num_heads (int): attention heads per block
            pe (str): positional-encoding strategy ("normal", "qk", or "apply")
            image_pe_method (str): how the dense image PE is built ("patch" or "image")
        """
        super().__init__()
        self.pe = pe
        pe_dim = transformer_dim // 2
        if self.pe == "apply":
            # PE is applied per-head inside attention; shrink it accordingly.
            pe_dim = pe_dim // num_heads
        self.pe_layer = PositionEmbeddingRandom(pe_dim, image_pe_method=image_pe_method)
        self.prompt_blocks = nn.ModuleList(
            [
                CrossAttenPromptBlock(dim=transformer_dim, num_heads=num_heads, first_block=(i == 0), pe=pe)
                for i in range(num_blocks)
            ]
        )
        # Lifts a scalar depth value to a transformer_dim token.
        self.depth2feature = nn.Sequential(
            nn.Linear(1, transformer_dim // 2),
            nn.GELU(),
            nn.Linear(transformer_dim // 2, transformer_dim),
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        prompt_depth: torch.Tensor,
        prompt_mask: torch.Tensor,
        patch_h: int,
        patch_w: int,
    ) -> torch.Tensor:
        """Fuse sparse depth prompts into the image embeddings.

        Arguments:
            image_embeddings (torch.Tensor): [B, patch_h*patch_w, C] image tokens
            prompt_depth (torch.Tensor): [B, 1, H, W] map holding the prompt depth values
            prompt_mask (torch.Tensor): [B, 1, H, W]; entries > 0 mark valid prompt pixels
            patch_h, patch_w (int): patch-grid size of the image embeddings

        Returns:
            torch.Tensor: [B, patch_h*patch_w, C] updated image embeddings.
            Batch items with no valid prompt points pass through unchanged.
        """
        B, _, H, W = prompt_depth.shape
        image_pe = self.pe_layer((patch_h, patch_w)).permute(1, 2, 0)  # CxHxW -> HxWxC
        image_embeddings_list = []
        for b in range(B):
            valid_pts_num = (prompt_mask[b, 0] > 0.0).sum()
            if valid_pts_num == 0:
                # No prompts for this item: keep its embeddings as-is.
                image_embeddings_list.append(image_embeddings[b : (b + 1)])
                continue
            # Normalized (row, col) centres of the valid prompt pixels.
            sparse_depth_pos = (prompt_mask[b, 0] > 0.0).nonzero().float()
            sparse_depth_pos[:, 0] = (sparse_depth_pos[:, 0] + 0.5) / H
            sparse_depth_pos[:, 1] = (sparse_depth_pos[:, 1] + 0.5) / W
            sparse_depth = prompt_depth[b, 0][prompt_mask[b, 0] > 0.0]
            prompt_embeddings = self.depth2feature(sparse_depth[:, None])[None, ...]  # 1, N, C
            # The PE layer expects (x, y) order, hence the [1, 0] column swap.
            prompt_pe = self.pe_layer._pe_encoding(sparse_depth_pos[None, :, [1, 0]])  # 1, N, C
            query_pe = image_pe.reshape(1, -1, image_pe.shape[-1])
            prompt = prompt_embeddings  # + prompt_pe
            query = image_embeddings[b : (b + 1)]  # + query_pe
            for block in self.prompt_blocks:
                query, prompt = block(query, query_pe, prompt, prompt_pe)
            image_embeddings_list.append(query[..., : image_embeddings.shape[-1]])
        # Updated prompt tokens are discarded; only image tokens are returned.
        image_embeddings = torch.cat(image_embeddings_list, dim=0)
        return image_embeddings
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class CrossAttenPromptBlock(nn.Module):
    """
    Self-attention block for prompt-based processing that handles both query and context features.
    Supports different positional encoding strategies.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        init_values: float = 0.0,
        first_block: bool = False,
        pe: str = "normal",
        **kwargs,
    ) -> None:
        super().__init__()
        self.first_block = first_block
        self.pe = pe

        # Attention components
        # norm1_x normalizes x before it queries context; norm1_x_after
        # normalizes the UPDATED x when context attends back to it.
        self.norm1_x = nn.LayerNorm(dim)
        self.norm1_x_after = nn.LayerNorm(dim)
        self.attn_x = MemEffCrossAttention(dim, context_dim=dim, num_heads=num_heads, pe=pe)
        self.ls1_x = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.norm1_context = nn.LayerNorm(dim)
        self.attn_context = MemEffCrossAttention(dim, context_dim=dim, num_heads=num_heads, pe=pe)
        self.ls1_context = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        # MLP components
        self.norm2_x = nn.LayerNorm(dim)
        self.mlp_x = Mlp(dim, hidden_features=dim * 4)
        self.ls2_x = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.norm2_context = nn.LayerNorm(dim)
        self.mlp_context = Mlp(dim, hidden_features=dim * 4)
        self.ls2_context = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

    def forward(self, x: Tensor, x_pe: Tensor, context: Tensor, context_pe: Tensor) -> Tuple[Tensor, Tensor]:
        """Bidirectional exchange: x attends to context, then context attends
        to the already-updated x; each stream then passes through its own MLP.
        Statement order matters — the second attention sees the new x."""
        # Apply positional encoding if this is the first block and using normal PE
        if self.pe == "normal" and self.first_block:
            x = x + x_pe
            context = context + context_pe

        # Handle positional encoding concatenation if needed
        if self.pe != "normal":
            # PE is passed into attention ("qk"/"apply" modes) rather than
            # added to the token streams up front.
            x = x + self.ls1_x(
                self.attn_x(self.norm1_x(x), context=self.norm1_context(context), x_pe=x_pe, context_pe=context_pe)
            )
            context = context + self.ls1_context(
                self.attn_context(
                    self.norm1_context(context), context=self.norm1_x_after(x), x_pe=context_pe, context_pe=x_pe
                )
            )
        else:
            # Apply standard attention
            x = x + self.ls1_x(self.attn_x(self.norm1_x(x), context=self.norm1_context(context)))
            context = context + self.ls1_context(
                self.attn_context(self.norm1_context(context), context=self.norm1_x_after(x))
            )

        # Apply MLP
        x = x + self.ls2_x(self.mlp_x(self.norm2_x(x)))
        context = context + self.ls2_context(self.mlp_context(self.norm2_context(context)))

        return x, context
|
InfiniDepth/model/block/prompt_models/rope.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Tuple
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
if torch.cuda.is_available():
|
| 7 |
+
acc_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
|
| 8 |
+
else:
|
| 9 |
+
acc_dtype = torch.float16
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PositionGetter:
    """Generates and caches (y, x) patch coordinates for a 2D grid.

    Results are cached per (height, width) so repeated calls for the same
    grid size avoid recomputation.
    """

    def __init__(self):
        # Maps (height, width) -> precomputed [H*W, 2] coordinate tensor.
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Return a [batch_size, height*width, 2] tensor of (y, x) grid positions."""
        key = (height, width)
        if key not in self.position_cache:
            rows = torch.arange(height, device=device)
            cols = torch.arange(width, device=device)
            # Row-major enumeration of every (y, x) pair in the grid.
            self.position_cache[key] = torch.cartesian_prod(rows, cols)
        base = self.position_cache[key]
        return base.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class RotaryPositionEmbedding2D(nn.Module):
|
| 51 |
+
"""2D Rotary Position Embedding implementation.
|
| 52 |
+
|
| 53 |
+
This module applies rotary position embeddings to input tokens based on their
|
| 54 |
+
2D spatial positions. It handles the position-dependent rotation of features
|
| 55 |
+
separately for vertical and horizontal dimensions.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
frequency: Base frequency for the position embeddings. Default: 100.0
|
| 59 |
+
scaling_factor: Scaling factor for frequency computation. Default: 1.0
|
| 60 |
+
|
| 61 |
+
Attributes:
|
| 62 |
+
base_frequency: Base frequency for computing position embeddings.
|
| 63 |
+
scaling_factor: Factor to scale the computed frequencies.
|
| 64 |
+
frequency_cache: Cache for storing precomputed frequency components.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0, feat_dim: int = 1024):
    """Initializes the 2D RoPE module."""
    super().__init__()
    self.base_frequency = frequency
    # NOTE(review): scaling_factor is stored but not visibly used in this
    # class's methods — confirm whether it is consumed elsewhere.
    self.scaling_factor = scaling_factor
    # Cache of (cos, sin) tables keyed by (dim, seq_len, device, dtype).
    self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
    # Hard-coded patch size of 14 — presumably a ViT-14 backbone such as
    # DINOv2; TODO confirm against the encoder config.
    self.patch_size = 14
    self.feat_dim = feat_dim
|
| 75 |
+
|
| 76 |
+
def _compute_frequency_components(
|
| 77 |
+
self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
|
| 78 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 79 |
+
"""Computes frequency components for rotary embeddings.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
dim: Feature dimension (must be even).
|
| 83 |
+
seq_len: Maximum sequence length.
|
| 84 |
+
device: Target device for computations.
|
| 85 |
+
dtype: Data type for the computed tensors.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
Tuple of (cosine, sine) tensors for frequency components.
|
| 89 |
+
"""
|
| 90 |
+
cache_key = (dim, seq_len, device, dtype)
|
| 91 |
+
if cache_key not in self.frequency_cache:
|
| 92 |
+
# Compute frequency bands
|
| 93 |
+
exponents = torch.arange(0, dim, 2, device=device).float() / dim
|
| 94 |
+
inv_freq = 1.0 / (self.base_frequency**exponents)
|
| 95 |
+
|
| 96 |
+
# Generate position-dependent frequencies
|
| 97 |
+
positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
| 98 |
+
angles = torch.einsum("i,j->ij", positions, inv_freq)
|
| 99 |
+
|
| 100 |
+
# Compute and cache frequency components
|
| 101 |
+
angles = angles.to(dtype)
|
| 102 |
+
angles = torch.cat((angles, angles), dim=-1)
|
| 103 |
+
cos_components = angles.cos().to(dtype)
|
| 104 |
+
sin_components = angles.sin().to(dtype)
|
| 105 |
+
self.frequency_cache[cache_key] = (cos_components, sin_components)
|
| 106 |
+
|
| 107 |
+
return self.frequency_cache[cache_key]
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def _rotate_features(x: torch.Tensor) -> torch.Tensor:
|
| 111 |
+
"""Performs feature rotation by splitting and recombining feature dimensions.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
x: Input tensor to rotate.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
Rotated feature tensor.
|
| 118 |
+
"""
|
| 119 |
+
feature_dim = x.shape[-1]
|
| 120 |
+
x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
|
| 121 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 122 |
+
|
| 123 |
+
def _apply_1d_rope(
|
| 124 |
+
self,
|
| 125 |
+
tokens: torch.Tensor,
|
| 126 |
+
positions: torch.Tensor,
|
| 127 |
+
cos_comp: torch.Tensor,
|
| 128 |
+
sin_comp: torch.Tensor,
|
| 129 |
+
) -> torch.Tensor:
|
| 130 |
+
"""Applies 1D rotary position embeddings along one dimension.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
tokens: Input token features.
|
| 134 |
+
positions: Position indices.
|
| 135 |
+
cos_comp: Cosine components for rotation.
|
| 136 |
+
sin_comp: Sine components for rotation.
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
Tokens with applied rotary position embeddings.
|
| 140 |
+
"""
|
| 141 |
+
# Embed positions with frequency components
|
| 142 |
+
cos = F.embedding(positions, cos_comp)[:, None, :, :]
|
| 143 |
+
sin = F.embedding(positions, sin_comp)[:, None, :, :]
|
| 144 |
+
|
| 145 |
+
# Apply rotation
|
| 146 |
+
return (tokens * cos) + (self._rotate_features(tokens) * sin)
|
| 147 |
+
|
| 148 |
+
def _forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
| 149 |
+
"""Applies 2D rotary position embeddings to input tokens.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
|
| 153 |
+
The feature dimension (dim) must be divisible by 4.
|
| 154 |
+
positions: Position tensor of shape (batch_size, n_tokens, 2) containing
|
| 155 |
+
the y and x coordinates for each token.
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
Tensor of same shape as input with applied 2D rotary position embeddings.
|
| 159 |
+
|
| 160 |
+
Raises:
|
| 161 |
+
AssertionError: If input dimensions are invalid or positions are malformed.
|
| 162 |
+
"""
|
| 163 |
+
# Validate inputs
|
| 164 |
+
assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
|
| 165 |
+
assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
|
| 166 |
+
|
| 167 |
+
# Compute feature dimension for each spatial direction
|
| 168 |
+
feature_dim = tokens.size(-1) // 2
|
| 169 |
+
|
| 170 |
+
# Get frequency components
|
| 171 |
+
max_position = int(positions.max()) + 1
|
| 172 |
+
cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
|
| 173 |
+
|
| 174 |
+
# Split features for vertical and horizontal processing
|
| 175 |
+
vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
|
| 176 |
+
|
| 177 |
+
# Apply RoPE separately for each dimension
|
| 178 |
+
vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
|
| 179 |
+
horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
|
| 180 |
+
|
| 181 |
+
# Combine processed features
|
| 182 |
+
return torch.cat((vertical_features, horizontal_features), dim=-1)
|
| 183 |
+
|
| 184 |
+
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
| 185 |
+
"""Positionally encode points that are normalized to [0,1]."""
|
| 186 |
+
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
| 187 |
+
max_position = int(coords.max()) + 1
|
| 188 |
+
cos_comp, sin_comp = self._compute_frequency_components(self.feat_dim, max_position, coords.device, acc_dtype)
|
| 189 |
+
vertical_cos = F.embedding(coords[..., 0], cos_comp)
|
| 190 |
+
vertical_sin = F.embedding(coords[..., 0], sin_comp)
|
| 191 |
+
horizontal_cos = F.embedding(coords[..., 1], cos_comp)
|
| 192 |
+
horizontal_sin = F.embedding(coords[..., 1], sin_comp)
|
| 193 |
+
# outputs d_1 x ... x d_n x C shape
|
| 194 |
+
return torch.cat((vertical_cos, vertical_sin, horizontal_cos, horizontal_sin), dim=-1)
|
| 195 |
+
|
| 196 |
+
def forward_encoding(self, size: Tuple[int, int], device: torch.device) -> torch.Tensor:
|
| 197 |
+
"""Generate positional encoding for a grid of the specified size."""
|
| 198 |
+
height, width = size
|
| 199 |
+
y_coords = torch.arange(height, device=device) * (self.patch_size * 2) + self.patch_size - 1
|
| 200 |
+
x_coords = torch.arange(width, device=device) * (self.patch_size * 2) + self.patch_size - 1
|
| 201 |
+
positions = torch.cartesian_prod(y_coords, x_coords) # h, w
|
| 202 |
+
max_position = int(positions.max()) + 1
|
| 203 |
+
cos_comp, sin_comp = self._compute_frequency_components(self.feat_dim, max_position, device, acc_dtype)
|
| 204 |
+
vertical_cos = F.embedding(positions[..., 0], cos_comp)
|
| 205 |
+
vertical_sin = F.embedding(positions[..., 0], sin_comp)
|
| 206 |
+
horizontal_cos = F.embedding(positions[..., 1], cos_comp)
|
| 207 |
+
horizontal_sin = F.embedding(positions[..., 1], sin_comp)
|
| 208 |
+
return (
|
| 209 |
+
torch.cat((vertical_cos, vertical_sin, horizontal_cos, horizontal_sin), dim=-1)
|
| 210 |
+
.reshape(height, width, -1)
|
| 211 |
+
.permute(2, 0, 1)
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
def forward(self, size: Tuple[int, int], device: torch.device) -> torch.Tensor:
|
| 215 |
+
return self.forward_encoding(size, device)
|
InfiniDepth/model/block/prompt_models/sam.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Tuple, Type
|
| 8 |
+
import torch
|
| 9 |
+
from .utils.pe_utils import PositionEmbeddingRandom
|
| 10 |
+
from .utils.transformer import TwoWayTransformer
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 14 |
+
# All rights reserved.
|
| 15 |
+
|
| 16 |
+
# This source code is licensed under the license found in the
|
| 17 |
+
# LICENSE file in the root directory of this source tree.
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SAMPromptModel(nn.Module):
    """Fuses sparse depth prompts into image embeddings.

    Uses a SAM-style TwoWayTransformer to cross-attend between depth prompt
    tokens (one per valid prompt pixel) and the image embedding tokens,
    independently for each batch item.
    """

    def __init__(
        self,
        *,
        transformer_dim: int,
        mlp_dim: int = 2048,
        num_heads: int = 8,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Args:
            transformer_dim (int): Channel dimension of the transformer tokens.
            mlp_dim (int): Hidden dimension of the transformer MLP blocks.
            num_heads (int): Number of attention heads.
            activation: Accepted for interface compatibility; not used by the
                current implementation.
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = TwoWayTransformer(
            depth=2, embedding_dim=transformer_dim, num_heads=num_heads, mlp_dim=mlp_dim
        )
        self.pe_layer = PositionEmbeddingRandom(transformer_dim // 2)
        # Maps a scalar depth value to a prompt token of width transformer_dim.
        self.depth2feature = nn.Sequential(
            nn.Linear(1, transformer_dim // 2), nn.ReLU(True), nn.Linear(transformer_dim // 2, transformer_dim)
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        prompt_depth: torch.Tensor,
        prompt_mask: torch.Tensor,
        patch_h: int,
        patch_w: int,
    ) -> torch.Tensor:
        """Fuses sparse depth prompts into the image embeddings.

        Args:
            image_embeddings (torch.Tensor): Image tokens, batch-first; each
                batch item is processed independently.
            prompt_depth (torch.Tensor): (B, 1, H, W) map of prompt depths.
            prompt_mask (torch.Tensor): (B, 1, H, W) mask; > 0 marks valid
                prompt pixels.
            patch_h (int): Height of the image-embedding patch grid.
            patch_w (int): Width of the image-embedding patch grid.

        Returns:
            torch.Tensor: Updated image embeddings. Batch items without any
            valid prompt pass through unchanged.
            (Fixed: the original annotated a Tuple although a single tensor
            is returned.)
        """
        B, _, H, W = prompt_depth.shape
        image_pe = self.pe_layer((patch_h, patch_w)).permute(1, 2, 0)  # CxHxW -> HxWxC

        image_embeddings_list = []
        for b in range(B):
            valid_pts_num = (prompt_mask[b, 0] > 0.0).sum()
            if valid_pts_num == 0:
                # No prompts for this item: keep its embedding untouched.
                image_embeddings_list.append(image_embeddings[b : (b + 1)])
                continue
            # Pixel-center (y, x) coordinates of valid prompts, normalized.
            sparse_depth_pos = (prompt_mask[b, 0] > 0.0).nonzero().float()
            sparse_depth_pos[:, 0] = (sparse_depth_pos[:, 0] + 0.5) / H
            sparse_depth_pos[:, 1] = (sparse_depth_pos[:, 1] + 0.5) / W
            sparse_depth = prompt_depth[b, 0][prompt_mask[b, 0] > 0.0]
            prompt_embeddings = self.depth2feature(sparse_depth[:, None])[None, ...]  # 1, N, C
            # _pe_encoding expects (x, y) order, hence the [1, 0] column swap.
            prompt_pe = self.pe_layer._pe_encoding(sparse_depth_pos[None, :, [1, 0]])  # 1, N, C
            # The updated prompt tokens are not used downstream.
            _prompt_out, image_embeddings_item = self.transformer(
                image_embeddings[b : (b + 1)],
                image_pe.reshape(1, -1, image_pe.shape[-1]),
                prompt_embeddings,
                prompt_pe,
            )
            image_embeddings_list.append(image_embeddings_item)
        image_embeddings = torch.cat(image_embeddings_list, dim=0)
        return image_embeddings
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# # Lightly adapted from
|
| 105 |
+
# # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
|
| 106 |
+
# class MLP(nn.Module):
|
| 107 |
+
# def __init__(
|
| 108 |
+
# self,
|
| 109 |
+
# input_dim: int,
|
| 110 |
+
# hidden_dim: int,
|
| 111 |
+
# output_dim: int,
|
| 112 |
+
# num_layers: int,
|
| 113 |
+
# sigmoid_output: bool = False,
|
| 114 |
+
# ) -> None:
|
| 115 |
+
# super().__init__()
|
| 116 |
+
# self.num_layers = num_layers
|
| 117 |
+
# h = [hidden_dim] * (num_layers - 1)
|
| 118 |
+
# self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
| 119 |
+
# self.sigmoid_output = sigmoid_output
|
| 120 |
+
|
| 121 |
+
# def forward(self, x):
|
| 122 |
+
# for i, layer in enumerate(self.layers):
|
| 123 |
+
# x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
| 124 |
+
# if self.sigmoid_output:
|
| 125 |
+
# x = F.sigmoid(x)
|
| 126 |
+
# return x
|
InfiniDepth/model/block/prompt_models/selfattn.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from ..perceive_io import LayerScale, MemEffAttention, Mlp
|
| 5 |
+
from .rope import RotaryPositionEmbedding2D
|
| 6 |
+
from .utils.pe_utils import PositionEmbeddingRandom
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
# Autocast accumulation dtype: bfloat16 on GPUs with compute capability >= 8
# (Ampere and newer), float16 everywhere else (older GPUs and CPU-only hosts).
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    acc_dtype = torch.bfloat16
else:
    acc_dtype = torch.float16
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SelfAttnPromptModel(nn.Module):
    """Fuses sparse depth prompts into image embeddings via self-attention.

    For each batch item with valid prompts, depth values are embedded into
    tokens, concatenated with the image tokens, and jointly refined by a stack
    of SelfAttenPromptBlock layers under autocast.
    """

    def __init__(
        self,
        transformer_dim: int = 1024,
        num_blocks: int = 1,
        num_heads: int = 4,
        pe: str = "normal",
        image_pe_method: str = "patch",  # or "image"
        **kwargs,
    ) -> None:
        """
        Args:
            transformer_dim (int): Token channel dimension.
            num_blocks (int): Number of SelfAttenPromptBlock layers.
            num_heads (int): Attention heads per block.
            pe (str): Positional-encoding strategy; "normal" adds PE once
                before the first block, other values pass PE into attention.
            image_pe_method (str): Forwarded to PositionEmbeddingRandom.
            **kwargs: Ignored; accepted for config compatibility.
        """
        super().__init__()
        self.pe = pe
        pe_dim = transformer_dim // 2
        if self.pe == "apply":
            # Per-head PE dimension when PE is applied inside attention.
            pe_dim = pe_dim // num_heads
        self.pe_layer = PositionEmbeddingRandom(pe_dim, image_pe_method=image_pe_method)
        self.prompt_blocks = nn.ModuleList(
            [
                SelfAttenPromptBlock(dim=transformer_dim, num_heads=num_heads, first_block=(i == 0), pe=pe)
                for i in range(num_blocks)
            ]
        )
        # Maps a scalar depth value to a prompt token of width transformer_dim.
        self.depth2feature = nn.Sequential(
            nn.Linear(1, transformer_dim // 2),
            nn.GELU(),
            nn.Linear(transformer_dim // 2, transformer_dim),
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        prompt_depth: torch.Tensor,
        prompt_mask: torch.Tensor,
        patch_h: int,
        patch_w: int,
    ) -> torch.Tensor:
        """Fuses sparse depth prompts into the image embeddings.

        Args:
            image_embeddings (torch.Tensor): Batch-first image tokens; each
                batch item is processed independently.
            prompt_depth (torch.Tensor): (B, 1, H, W) map of prompt depths.
            prompt_mask (torch.Tensor): (B, 1, H, W) mask; > 0 marks valid
                prompt pixels.
            patch_h (int): Height of the image-embedding patch grid.
            patch_w (int): Width of the image-embedding patch grid.

        Returns:
            torch.Tensor: Updated image embeddings. Batch items without any
            valid prompt pass through unchanged. (Fixed: the original
            annotated a Tuple although a single tensor is returned.)
        """
        B, _, H, W = prompt_depth.shape
        image_pe = self.pe_layer((patch_h, patch_w)).permute(1, 2, 0)  # CxHxW -> HxWxC
        image_embeddings_list = []
        for b in range(B):
            valid_pts_num = (prompt_mask[b, 0] > 0.0).sum()
            if valid_pts_num == 0:
                # No prompts for this item: keep its embedding untouched.
                image_embeddings_list.append(image_embeddings[b : (b + 1)])
                continue
            # Pixel-center (y, x) coordinates of valid prompts, normalized.
            sparse_depth_pos = (prompt_mask[b, 0] > 0.0).nonzero().float()
            sparse_depth_pos[:, 0] = (sparse_depth_pos[:, 0] + 0.5) / H
            sparse_depth_pos[:, 1] = (sparse_depth_pos[:, 1] + 0.5) / W
            sparse_depth = prompt_depth[b, 0][prompt_mask[b, 0] > 0.0]
            prompt_embeddings = self.depth2feature(sparse_depth[:, None])[None, ...]  # 1, N, C
            # _pe_encoding expects (x, y) order, hence the [1, 0] column swap.
            prompt_pe = self.pe_layer._pe_encoding(sparse_depth_pos[None, :, [1, 0]])  # 1, N, C
            query_pe = image_pe.reshape(1, -1, image_pe.shape[-1])
            prompt = prompt_embeddings
            query = image_embeddings[b : (b + 1)]
            with torch.autocast("cuda", enabled=True, dtype=acc_dtype):
                for block in self.prompt_blocks:
                    query, prompt = block(query, query_pe, prompt, prompt_pe)
            image_embeddings_list.append(query[..., : image_embeddings.shape[-1]])
            # (Removed the original's prompt_embeddings_list: it was appended
            # to but never read, needlessly keeping tensors alive.)
        image_embeddings = torch.cat(image_embeddings_list, dim=0)
        return image_embeddings
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class SelfAttnRopePromptModel(nn.Module):
    """Depth-prompt fusion via self-attention with RoPE positional encoding.

    Same structure as SelfAttnPromptModel, but when ``pe`` starts with
    "rope" the positional encoding is a RotaryPositionEmbedding2D (with the
    base frequency parsed from the ``pe`` string, e.g. "rope100"), and prompt
    positions are integer pixel indices instead of normalized coordinates.
    """

    def __init__(
        self,
        transformer_dim: int = 1024,
        num_blocks: int = 1,
        num_heads: int = 4,
        pe: str = "normal",
        image_pe_method: str = "patch",  # or "image"
        **kwargs,
    ) -> None:
        """
        Args:
            transformer_dim (int): Token channel dimension.
            num_blocks (int): Number of SelfAttenPromptBlock layers.
            num_heads (int): Attention heads per block.
            pe (str): PE strategy; "rope<freq>" selects 2D RoPE with base
                frequency <freq>, otherwise random Fourier features are used.
            image_pe_method (str): Forwarded to PositionEmbeddingRandom.
            **kwargs: Ignored; accepted for config compatibility.
        """
        super().__init__()
        self.pe = pe
        pe_dim = transformer_dim // 2
        if self.pe == "apply":
            pe_dim = pe_dim // num_heads
        if self.pe.startswith("rope"):
            self.pe_layer = RotaryPositionEmbedding2D(
                frequency=float(self.pe.split("rope")[1]), feat_dim=pe_dim // num_heads
            )
        else:
            # NOTE(review): this branch's forward() passes a `device` kwarg
            # that PositionEmbeddingRandom.forward does not accept — the
            # non-rope path appears unused; confirm before relying on it.
            self.pe_layer = PositionEmbeddingRandom(pe_dim, image_pe_method=image_pe_method)
        self.prompt_blocks = nn.ModuleList(
            [
                SelfAttenPromptBlock(
                    dim=transformer_dim, num_heads=num_heads, first_block=(i == 0), pe=pe, use_sep=False
                )
                for i in range(num_blocks)
            ]
        )
        # Maps a scalar depth value to a prompt token of width transformer_dim.
        self.depth2feature = nn.Sequential(
            nn.Linear(1, transformer_dim // 2),
            nn.GELU(),
            nn.Linear(transformer_dim // 2, transformer_dim),
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        prompt_depth: torch.Tensor,
        prompt_mask: torch.Tensor,
        patch_h: int,
        patch_w: int,
    ) -> torch.Tensor:
        """Fuses sparse depth prompts into the image embeddings.

        Args:
            image_embeddings (torch.Tensor): Batch-first image tokens; each
                batch item is processed independently.
            prompt_depth (torch.Tensor): (B, 1, H, W) map of prompt depths.
            prompt_mask (torch.Tensor): (B, 1, H, W) mask; > 0 marks valid
                prompt pixels.
            patch_h (int): Height of the image-embedding patch grid.
            patch_w (int): Width of the image-embedding patch grid.

        Returns:
            torch.Tensor: Updated image embeddings. Batch items without any
            valid prompt pass through unchanged. (Fixed: the original
            annotated a Tuple although a single tensor is returned.)
        """
        B, _, H, W = prompt_depth.shape
        image_pe = self.pe_layer((patch_h, patch_w), device=prompt_depth.device).permute(1, 2, 0)  # CxHxW -> HxWxC
        image_embeddings_list = []
        for b in range(B):
            valid_pts_num = (prompt_mask[b, 0] > 0.0).sum()
            if valid_pts_num == 0:
                # No prompts for this item: keep its embedding untouched.
                image_embeddings_list.append(image_embeddings[b : (b + 1)])
                continue
            # Integer (y, x) prompt positions, doubled to match the RoPE
            # grid stride used by forward_encoding.
            # NOTE(review): .int() yields int32 indices for F.embedding inside
            # _pe_encoding; torch versions that require int64 would raise —
            # confirm against the project's pinned torch version.
            sparse_depth_pos = (prompt_mask[b, 0] > 0.0).nonzero().int()
            sparse_depth_pos[:, 0] = sparse_depth_pos[:, 0] * 2
            sparse_depth_pos[:, 1] = sparse_depth_pos[:, 1] * 2
            sparse_depth = prompt_depth[b, 0][prompt_mask[b, 0] > 0.0]
            prompt_embeddings = self.depth2feature(sparse_depth[:, None])[None, ...]  # 1, N, C
            prompt_pe = self.pe_layer._pe_encoding(sparse_depth_pos[None])  # 1, N, C
            query_pe = image_pe.reshape(1, -1, image_pe.shape[-1])
            prompt = prompt_embeddings
            query = image_embeddings[b : (b + 1)]
            with torch.autocast("cuda", enabled=True, dtype=acc_dtype):
                for block in self.prompt_blocks:
                    query, prompt = block(query, query_pe, prompt, prompt_pe)
            image_embeddings_list.append(query[..., : image_embeddings.shape[-1]])
            # (Removed the original's prompt_embeddings_list: it was appended
            # to but never read, needlessly keeping tensors alive.)
        image_embeddings = torch.cat(image_embeddings_list, dim=0)
        return image_embeddings
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class SelfAttenPromptBlock(nn.Module):
    """
    Self-attention block for prompt-based processing that handles both query
    and context features. Query and context tokens are concatenated (with an
    optional learned separator token), run through attention + MLP with
    residuals, then split back apart. Supports different positional-encoding
    strategies via ``pe``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        init_values: float = 0.0,
        first_block: bool = False,
        pe: str = "normal",
        use_sep: bool = True,
        **kwargs,
    ) -> None:
        """
        Args:
            dim: Token channel dimension.
            num_heads: Number of attention heads.
            init_values: LayerScale init value; 0 disables LayerScale.
            first_block: Whether this is the first block (controls the
                one-time additive PE in "normal" mode).
            pe: PE strategy; "normal" adds PE to tokens, other values pass PE
                into the attention op.
            use_sep: Insert a learned separator token between query and
                context.
            **kwargs: Ignored; accepted for config compatibility.
        """
        super().__init__()
        self.first_block = first_block
        self.pe = pe

        # Attention components
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MemEffAttention(dim, num_heads=num_heads, pe=pe)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        # MLP components
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(dim, hidden_features=dim * 4)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        # Separator token for concatenating query and context
        pe_dim = dim
        self.use_sep = use_sep
        if use_sep:
            self.sep = nn.Parameter(torch.randn(1, 1, pe_dim))
        else:
            self.sep = None

        # Special separator for positional encoding if needed
        if self.use_sep:
            if self.pe != "normal":
                self.sep_pe = nn.Parameter(torch.randn(1, 1, pe_dim))
            else:
                self.sep_pe = None

    def forward(self, x: Tensor, x_pe: Tensor, context: Tensor, context_pe: Tensor) -> Tuple[Tensor, Tensor]:
        """Jointly attend over query (x) and context tokens; returns both."""
        # Apply positional encoding if this is the first block and using normal PE
        if self.pe == "normal" and self.first_block:
            x = x + x_pe
            context = context + context_pe

        # Record original sequence lengths so the result can be split back.
        x_len, context_len = x.shape[1], context.shape[1]

        # Concatenate query, separator token, and context.
        # Fix: expand the (1, 1, C) separator across the batch — the original
        # torch.cat failed for batch sizes > 1 (identical values for B == 1).
        if self.use_sep:
            x = torch.cat([x, self.sep.expand(x.shape[0], -1, -1), context], dim=1)
        else:
            x = torch.cat([x, context], dim=1)

        # Handle positional encoding concatenation if needed
        if self.pe != "normal":
            if self.use_sep:
                x_pe = torch.cat([x_pe, self.sep_pe.expand(x_pe.shape[0], -1, -1), context_pe], dim=1)
            else:
                x_pe = torch.cat([x_pe, context_pe], dim=1)
            x = x + self.ls1(self.attn(self.norm1(x), x_pe))
        else:
            # Apply standard attention (PE was already added to the tokens).
            x = x + self.ls1(self.attn(self.norm1(x)))

        # MLP with residual
        x = x + self.ls2(self.mlp(self.norm2(x)))

        # Split back into query and context (skip the separator slot if any).
        query = x[:, :x_len, :]
        if self.use_sep:
            context = x[:, x_len + 1 : x_len + 1 + context_len, :]
        else:
            context = x[:, x_len : x_len + context_len, :]

        return query, context
|
InfiniDepth/model/block/prompt_models/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Prompt model utility modules."""
|
InfiniDepth/model/block/prompt_models/utils/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (189 Bytes). View file
|
|
|
InfiniDepth/model/block/prompt_models/utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (252 Bytes). View file
|
|
|
InfiniDepth/model/block/prompt_models/utils/__pycache__/pe_utils.cpython-310.pyc
ADDED
|
Binary file (2.9 kB). View file
|
|
|
InfiniDepth/model/block/prompt_models/utils/__pycache__/pe_utils.cpython-311.pyc
ADDED
|
Binary file (5.25 kB). View file
|
|
|
InfiniDepth/model/block/prompt_models/utils/__pycache__/transformer.cpython-310.pyc
ADDED
|
Binary file (7.31 kB). View file
|
|
|
InfiniDepth/model/block/prompt_models/utils/__pycache__/transformer.cpython-311.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
InfiniDepth/model/block/prompt_models/utils/pe_utils.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Optional, Tuple
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.

    2D coordinates are projected through a fixed random Gaussian matrix and
    encoded as sin/cos pairs, producing ``2 * num_pos_feats`` channels.
    """

    # Patch size used by the "image" method to build a full-resolution
    # encoding that is conv-downsampled back to the patch grid.
    patch_size = 14

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None, image_pe_method: str = "patch") -> None:
        """
        Args:
            num_pos_feats: Half of the output channel count (sin + cos).
            scale: Std-dev of the random projection; None or <= 0 maps to 1.0.
            image_pe_method: "patch" for a per-patch grid encoding, "image"
                to additionally fold in a downsampled full-resolution one.
        """
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        # Fixed (untrained) projection; a buffer so it follows .to(device)
        # and is stored in state_dict.
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

        self.image_pe_method = image_pe_method
        if self.image_pe_method == "image":
            self.patch_embed = nn.Sequential(
                nn.Conv2d(num_pos_feats * 2, num_pos_feats // 2, kernel_size=2, stride=2),
                nn.ReLU(),
                nn.Conv2d(
                    num_pos_feats // 2,
                    num_pos_feats * 2,
                    kernel_size=self.patch_size // 2,
                    stride=self.patch_size // 2,
                ),
            )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1].

        Args:
            coords: Tensor of shape d_1 x ... x d_n x 2 with values in [0, 1].

        Returns:
            Tensor of shape d_1 x ... x d_n x (2 * num_pos_feats).
        """
        coords = 2 * coords - 1  # map [0, 1] -> [-1, 1]
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward_encoding(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        # Pixel-center coordinates, normalized to [0, 1].
        y_embed = (grid.cumsum(dim=0) - 0.5) / h
        x_embed = (grid.cumsum(dim=1) - 0.5) / w
        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))  # HxWx2 -> HxWxC
        return pe.permute(2, 0, 1)  # C x H x W

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Return a (C, H, W) encoding for the given grid size.

        Raises:
            ValueError: If ``image_pe_method`` is not "patch" or "image".
        """
        if self.image_pe_method == "patch":
            return self.forward_encoding(size)
        elif self.image_pe_method == "image":
            pe_encoding = self.forward_encoding(size)
            pe_encoding_high = self.forward_encoding((size[0] * self.patch_size, size[1] * self.patch_size))
            return pe_encoding + self.patch_embed(pe_encoding_high[None])[0]
        # The original fell through and silently returned None here.
        raise ValueError(f"Unknown image_pe_method: {self.image_pe_method!r}")

    def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]  # x normalized by width
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]  # y normalized by height
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
|
InfiniDepth/model/block/prompt_models/utils/transformer.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from typing import Tuple, Type
|
| 9 |
+
import torch
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
from ...common import MLPBlock
|
| 12 |
+
from .....utils.logger import Log
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from xformers.ops import memory_efficient_attention
|
| 17 |
+
|
| 18 |
+
XFORMERS_AVAILABLE = True
|
| 19 |
+
except ImportError:
|
| 20 |
+
Log.warning("xFormers not available")
|
| 21 |
+
XFORMERS_AVAILABLE = False
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TwoWayTransformer(nn.Module):
    """Transformer decoder that attends to an image with supplied-PE queries.

    A stack of ``TwoWayAttentionBlock``s (each updating both the token and
    image streams) followed by a final token-to-image attention and LayerNorm.
    """

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        Args:
            depth: number of two-way attention blocks.
            embedding_dim: channel dimension of the input embeddings.
            num_heads: number of attention heads; must divide embedding_dim.
            mlp_dim: hidden channel dimension of the MLP blocks.
            activation: activation module used in the MLP blocks.
            attention_downsample_rate: channel downsample factor inside attention.
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        # Only the very first block skips adding PE inside its self-attention.
        self.layers = nn.ModuleList(
            [
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(block_idx == 0),
                )
                for block_idx in range(depth)
            ]
        )

        self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self, image_embedding: Tensor, image_pe: Tensor, point_embedding: Tensor, point_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            image_embedding: image features to attend to
                (B x embedding_dim x h x w for any h and w).
            image_pe: positional encoding added to the image; same shape as
                image_embedding.
            point_embedding: query-point embedding
                (B x N_points x embedding_dim for any N_points).
            point_pe: positional encoding added to the query points.

        Returns:
            Tuple of (processed point_embedding, processed image_embedding).
        """
        # Queries are the sparse point tokens; keys are the dense image tokens.
        queries = point_embedding
        keys = image_embedding

        # Each block makes queries keys-aware and keys queries-aware.
        for block in self.layers:
            queries, keys = block(
                queries=queries,
                keys=keys,
                query_pe=point_pe,
                key_pe=image_pe,
            )

        # Final attention from the points to the image, then a LayerNorm.
        # NOTE(review): queries are augmented with point_embedding (not
        # point_pe) here, mirroring the original — confirm this is intended.
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = self.norm_final_attn(queries + attn_out)

        return queries, keys
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        # (1) Self-attention over the sparse (token) stream.
        self.self_attn = MemEffAttention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        # (2) Cross-attention: tokens attend to the dense image embedding,
        # with channels downsampled by attention_downsample_rate.
        self.cross_attn_token_to_image = MemEffAttention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        # (3) Per-token MLP applied to the sparse stream.
        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        # (4) Cross-attention in the reverse direction: the image embedding
        # attends back to the tokens.
        # NOTE(review): this layer uses plain Attention while the others use
        # MemEffAttention — presumably intentional, but worth confirming.
        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
        """Run the four sub-layers in order; returns updated (queries, keys)."""
        # Self attention block
        #
        if self.skip_first_layer_pe:
            # First layer attends without positional encodings and replaces
            # queries outright (no residual add).
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            # PE is added to q/k only; values stay PE-free.
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        # (note the swapped roles: image tokens are the queries here).
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        """
        Args:
            embedding_dim: channel dimension of the inputs and output.
            num_heads: number of attention heads; must divide the internal dim.
            downsample_rate: factor by which channels are reduced internally.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        # Project down to internal_dim for q/k/v, back up for the output.
        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """Reshape B x N x C into B x num_heads x N x C_per_head."""
        batch, tokens, channels = x.shape
        return x.reshape(batch, tokens, num_heads, channels // num_heads).transpose(1, 2)

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """Reshape B x num_heads x N x C_per_head back into B x N x C."""
        batch, heads, tokens, per_head = x.shape
        return x.transpose(1, 2).reshape(batch, tokens, heads * per_head)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Scaled dot-product attention over the projected inputs."""
        # Project and split into heads.
        q = self._separate_heads(self.q_proj(q), self.num_heads)
        k = self._separate_heads(self.k_proj(k), self.num_heads)
        v = self._separate_heads(self.v_proj(v), self.num_heads)

        # softmax(q k^T / sqrt(d)) v, per head.
        per_head = q.shape[-1]
        scores = q @ k.transpose(-2, -1)  # B x N_heads x N_tokens x N_tokens
        scores = scores / math.sqrt(per_head)
        weights = torch.softmax(scores, dim=-1)

        # Weighted values, merged heads, output projection.
        merged = self._recombine_heads(weights @ v)
        return self.out_proj(merged)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class MemEffAttention(Attention):
    """Attention variant backed by xFormers' memory-efficient kernel.

    Falls back to the standard ``Attention.forward`` math when xFormers is
    not installed, instead of crashing at call time.
    """

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Compute attention; inputs and output are B x N x embedding_dim."""
        if not XFORMERS_AVAILABLE:
            # Bug fix: the original called memory_efficient_attention
            # unconditionally, raising NameError whenever the guarded
            # xformers import at the top of this file had failed.
            return super().forward(q, k, v)
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)
        # xformers expects B x N x heads x head_dim (heads before channels).
        q = rearrange(q, "b n (m c) -> b n m c", m=self.num_heads)
        k = rearrange(k, "b n (m c) -> b n m c", m=self.num_heads)
        v = rearrange(v, "b n (m c) -> b n m c", m=self.num_heads)
        x = memory_efficient_attention(q, k, v, attn_bias=None)
        x = rearrange(x, "b n m c -> b n (m c)")
        return self.out_proj(x)
|
InfiniDepth/model/block/rope.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=16.0):
    """Precompute 2D rotary-embedding frequencies as complex exponentials.

    Coordinates along each axis are spread evenly over [0, scale] and each
    axis receives dim // 4 frequency bands.

    Args:
        dim: head dimension; dim // 4 bands are generated per spatial axis.
        height: grid height in positions.
        width: grid width in positions.
        theta: base of the inverse-frequency geometric progression.
        scale: upper bound of the coordinate range for both axes.

    Returns:
        Complex tensor of shape (height * width, dim // 2), with x-axis and
        y-axis phases interleaved along the last dimension.
    """
    xs = torch.linspace(0, scale, width)
    ys = torch.linspace(0, scale, height)
    ys, xs = torch.meshgrid(ys, xs, indexing="ij")
    xs = xs.reshape(-1)
    ys = ys.reshape(-1)
    # Inverse-frequency bands: theta^(-4k/dim) for k = 0 .. dim//4 - 1.
    inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    x_phase = torch.outer(xs, inv_freqs).float()  # (N, dim//4)
    y_phase = torch.outer(ys, inv_freqs).float()  # (N, dim//4)
    # Unit-magnitude complex numbers carrying the rotation angles.
    x_cis = torch.polar(torch.ones_like(x_phase), x_phase)
    y_cis = torch.polar(torch.ones_like(y_phase), y_phase)
    # Interleave the x and y components: (N, dim//4, 2) -> (N, dim//2).
    cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1)
    return cis.reshape(height * width, -1)
|
| 21 |
+
|
| 22 |
+
def precompute_freqs_cis_ex2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=1.0):
    """Precompute 2D rotary frequencies with a size-proportional extent.

    Unlike ``precompute_freqs_cis_2d``, the coordinate range grows with the
    grid size (scaled by ``scale``), making the phases resolution-dependent.

    Args:
        dim: head dimension; dim // 4 frequency bands per spatial axis.
        height: grid height in positions.
        width: grid width in positions.
        theta: base of the inverse-frequency geometric progression.
        scale: a float or an (sx, sy) pair scaling the coordinate extents.

    Returns:
        Complex tensor of shape (height * width, dim // 2).
    """
    if isinstance(scale, float):
        scale = (scale, scale)
    # NOTE(review): the x extent is derived from `height` and the y extent
    # from `width`, which looks swapped — confirm this is intentional before
    # relying on non-square grids.
    xs = torch.linspace(0, height * scale[0], width)
    ys = torch.linspace(0, width * scale[1], height)
    ys, xs = torch.meshgrid(ys, xs, indexing="ij")
    xs = xs.reshape(-1)
    ys = ys.reshape(-1)
    # Inverse-frequency bands: theta^(-4k/dim) for k = 0 .. dim//4 - 1.
    inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    x_phase = torch.outer(xs, inv_freqs).float()  # (N, dim//4)
    y_phase = torch.outer(ys, inv_freqs).float()  # (N, dim//4)
    x_cis = torch.polar(torch.ones_like(x_phase), x_phase)
    y_cis = torch.polar(torch.ones_like(y_phase), y_phase)
    # Interleave x and y components: (N, dim//4, 2) -> (N, dim//2).
    cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1)
    return cis.reshape(height * width, -1)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by precomputed complex phases.

    Consecutive channel pairs of the last dimension are interpreted as
    complex numbers and multiplied by ``freqs_cis``.

    Args:
        xq: queries of shape (B, N, H, Hc).
        xk: keys with the same layout as xq.
        freqs_cis: complex phases broadcast against the complex view of xq
            via two leading singleton dims; its shape must match xq's dims
            2 and 3 (after pairing), e.g. (H, Hc // 2).
            NOTE(review): assumes tokens occupy xq's dim that freqs_cis
            indexes — confirm the intended layout against callers.

    Returns:
        (rotated queries, rotated keys), each cast back to its input dtype.
    """
    phases = freqs_cis[None, None, :, :]
    # View channel pairs as complex numbers: (..., Hc) -> (..., Hc/2) complex.
    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Complex multiply applies the rotation; flatten back to real pairs.
    q_rot = torch.view_as_real(q_complex * phases).flatten(3)
    k_rot = torch.view_as_real(k_complex * phases).flatten(3)
    return q_rot.type_as(xq), k_rot.type_as(xk)
|
| 52 |
+
|
| 53 |
+
def apply_rotary_emb_crossattention(
    xq: torch.Tensor,
    xk: torch.Tensor,
    yk: torch.Tensor,
    freqs_cis1: torch.Tensor,
    freqs_cis2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Rotate queries/keys for cross-attention with two phase sets.

    ``xq`` and ``xk`` are rotated by ``freqs_cis1``; the second key stream
    ``yk`` is rotated by ``freqs_cis2``.

    Args:
        xq: queries of shape (B, N, H, Hc).
        xk: keys sharing xq's phase grid.
        yk: keys from the other modality, with their own phase grid.
        freqs_cis1: complex phases for xq/xk, broadcast over the two
            leading dims.
        freqs_cis2: complex phases for yk.

    Returns:
        (rotated xq, rotated xk, rotated yk), cast back to input dtypes.
    """
    phases1 = freqs_cis1[None, None, :, :]
    phases2 = freqs_cis2[None, None, :, :]
    # View channel pairs as complex numbers: (..., Hc) -> (..., Hc/2) complex.
    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    yk_complex = torch.view_as_complex(yk.float().reshape(*yk.shape[:-1], -1, 2))
    # Apply the rotations and flatten back to interleaved real pairs.
    q_rot = torch.view_as_real(q_complex * phases1).flatten(3)
    k_rot = torch.view_as_real(k_complex * phases1).flatten(3)
    yk_rot = torch.view_as_real(yk_complex * phases2).flatten(3)
    return q_rot.type_as(xq), k_rot.type_as(xk), yk_rot.type_as(yk)
|
InfiniDepth/model/block/torchhub/README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Local PyTorch Hub
|
| 2 |
+
|
| 3 |
+
This directory is for loading the DINOv2 encoder locally in case of no Internet connection.
|
InfiniDepth/model/block/torchhub/dinov3/.docstr.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
paths:
|
| 2 |
+
- dinov3
|
| 3 |
+
exclude: dinov3/tests
|
| 4 |
+
skip_init: True
|
| 5 |
+
skip_private: True
|
| 6 |
+
fail_under: 0
|
InfiniDepth/model/block/torchhub/dinov3/.github/workflows/lint.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Lint
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
pull_request:
|
| 8 |
+
branches:
|
| 9 |
+
- main
|
| 10 |
+
|
| 11 |
+
jobs:
|
| 12 |
+
run-linters:
|
| 13 |
+
name: Run linters
|
| 14 |
+
runs-on: ubuntu-latest
|
| 15 |
+
|
| 16 |
+
steps:
|
| 17 |
+
- name: Checkout repository
|
| 18 |
+
uses: actions/checkout@v4
|
| 19 |
+
- name: Set up Python
|
| 20 |
+
uses: actions/setup-python@v5
|
| 21 |
+
with:
|
| 22 |
+
python-version: 3.11
|
| 23 |
+
cache: 'pip'
|
| 24 |
+
cache-dependency-path: '**/requirements*.txt'
|
| 25 |
+
- name: Install Python (development) dependencies
|
| 26 |
+
run: |
|
| 27 |
+
pip install -r requirements-dev.txt
|
| 28 |
+
- name: Run ruff (linter)
|
| 29 |
+
run: |
|
| 30 |
+
ruff check dinov3
|
| 31 |
+
- name: Run ruff (formatter)
|
| 32 |
+
if: always()
|
| 33 |
+
run: |
|
| 34 |
+
ruff format --diff dinov3
|
| 35 |
+
- name: Report docstring coverage
|
| 36 |
+
if: always()
|
| 37 |
+
run: |
|
| 38 |
+
docstr-coverage dinov3
|
| 39 |
+
- name: Run mypy
|
| 40 |
+
if: always()
|
| 41 |
+
run: |
|
| 42 |
+
mypy --txt-report .
|
| 43 |
+
[ -f index.txt ] && cat index.txt
|
| 44 |
+
- name: Run pylint
|
| 45 |
+
if: always()
|
| 46 |
+
run: |
|
| 47 |
+
pylint --exit-zero dinov3
|
InfiniDepth/model/block/torchhub/dinov3/.gitignore
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
build/
|
| 2 |
+
dist/
|
| 3 |
+
*.egg-info/
|
| 4 |
+
**/__pycache__/
|
| 5 |
+
|
| 6 |
+
**/.ipynb_checkpoints
|
| 7 |
+
**/.ipynb_checkpoints/**
|
| 8 |
+
|
| 9 |
+
**/notebooks
|
| 10 |
+
|
| 11 |
+
# Ignore shell scripts
|
| 12 |
+
*.sh
|
| 13 |
+
|
| 14 |
+
# Ignore swap files
|
| 15 |
+
*.swp
|
| 16 |
+
|
| 17 |
+
# Ignore vscode directory
|
| 18 |
+
.vscode/
|
InfiniDepth/model/block/torchhub/dinov3/CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to make participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
| 49 |
+
an individual is representing the project or its community in public spaces.
|
| 50 |
+
Examples of representing a project or community include using an official
|
| 51 |
+
project e-mail address, posting via an official social media account, or acting
|
| 52 |
+
as an appointed representative at an online or offline event. Representation of
|
| 53 |
+
a project may be further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
| 56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
| 57 |
+
the project or its community.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported by contacting the project team at <opensource-conduct@meta.com>. All
|
| 63 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 66 |
+
Further details of specific enforcement policies may be posted separately.
|
| 67 |
+
|
| 68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 69 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 70 |
+
members of the project's leadership.
|
| 71 |
+
|
| 72 |
+
## Attribution
|
| 73 |
+
|
| 74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
| 75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
| 76 |
+
|
| 77 |
+
[homepage]: https://www.contributor-covenant.org
|
| 78 |
+
|
| 79 |
+
For answers to common questions about this code of conduct, see
|
| 80 |
+
https://www.contributor-covenant.org/faq
|
InfiniDepth/model/block/torchhub/dinov3/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to DINOv3
|
| 2 |
+
We want to make contributing to this project as easy and transparent as
|
| 3 |
+
possible.
|
| 4 |
+
|
| 5 |
+
## Pull Requests
|
| 6 |
+
We actively welcome your pull requests.
|
| 7 |
+
|
| 8 |
+
1. Fork the repo and create your branch from `main`.
|
| 9 |
+
2. If you've added code that should be tested, add tests.
|
| 10 |
+
3. If you've changed APIs, update the documentation.
|
| 11 |
+
4. Ensure the test suite passes.
|
| 12 |
+
5. Make sure your code lints.
|
| 13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
| 14 |
+
|
| 15 |
+
## Contributor License Agreement ("CLA")
|
| 16 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
| 17 |
+
to do this once to work on any of Meta's open source projects.
|
| 18 |
+
|
| 19 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
| 20 |
+
|
| 21 |
+
## Issues
|
| 22 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
| 23 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
| 24 |
+
|
| 25 |
+
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
| 26 |
+
disclosure of security bugs. In those cases, please go through the process
|
| 27 |
+
outlined on that page and do not file a public issue.
|
| 28 |
+
|
| 29 |
+
## License
|
| 30 |
+
By contributing to DINOv3, you agree that your contributions will be licensed
|
| 31 |
+
under the LICENSE.md file in the root directory of this source tree.
|
InfiniDepth/model/block/torchhub/dinov3/LICENSE.md
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DINOv3 License
|
| 2 |
+
|
| 3 |
+
*Last Updated: August 19, 2025*
|
| 4 |
+
|
| 5 |
+
**“Agreement”** means the terms and conditions for use, reproduction, distribution and modification of the DINO Materials set forth herein.
|
| 6 |
+
|
| 7 |
+
**“DINO Materials”** means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.
|
| 8 |
+
|
| 9 |
+
**“Documentation”** means the specifications, manuals and documentation accompanying
|
| 10 |
+
DINO Materials distributed by Meta.
|
| 11 |
+
|
| 12 |
+
**“Licensee”** or **“you”** means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
| 13 |
+
|
| 14 |
+
**“Meta”** or **“we”** means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
|
| 15 |
+
|
| 16 |
+
**“Sanctions”** means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.
|
| 17 |
+
|
| 18 |
+
**“Trade Controls”** means any of the following: Sanctions and applicable export and import controls.
|
| 19 |
+
|
| 20 |
+
By clicking “I Accept” below or by using or distributing any portion or element of the DINO Materials, you agree to be bound by this Agreement.
|
| 21 |
+
|
| 22 |
+
## 1. License Rights and Redistribution.
|
| 23 |
+
|
| 24 |
+
a. <ins>Grant of Rights</ins>. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the DINO Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the DINO Materials.
|
| 25 |
+
|
| 26 |
+
b. <ins>Redistribution and Use</ins>.
|
| 27 |
+
|
| 28 |
+
i. Distribution of DINO Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the DINO Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such DINO Materials.
|
| 29 |
+
|
| 30 |
+
ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with DINO Materials, you must acknowledge the use of DINO Materials in your publication.
|
| 31 |
+
|
| 32 |
+
iii. Your use of the DINO Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.
|
| 33 |
+
|
| 34 |
+
iv. Your use of the DINO Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the DINO Materials.
|
| 35 |
+
|
| 36 |
+
v. You are not the target of Trade Controls and your use of DINO Materials must comply with Trade Controls. You agree not to use, or permit others to use, DINO Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.
|
| 37 |
+
|
| 38 |
+
## 2. User Support.
|
| 39 |
+
|
| 40 |
+
Your use of the DINO Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the DINO Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
|
| 41 |
+
|
| 42 |
+
## 3. Disclaimer of Warranty.
|
| 43 |
+
|
| 44 |
+
UNLESS REQUIRED BY APPLICABLE LAW, THE DINO MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE DINO MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE DINO MATERIALS AND ANY OUTPUT AND RESULTS.
|
| 45 |
+
|
| 46 |
+
## 4. Limitation of Liability.
|
| 47 |
+
|
| 48 |
+
IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 49 |
+
|
| 50 |
+
## 5. Intellectual Property.
|
| 51 |
+
|
| 52 |
+
a. Subject to Meta’s ownership of DINO Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the DINO Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
|
| 53 |
+
|
| 54 |
+
b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the DINO Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the DINO Materials.
|
| 55 |
+
|
| 56 |
+
## 6. Term and Termination.
|
| 57 |
+
|
| 58 |
+
The term of this Agreement will commence upon your acceptance of this Agreement or access to the DINO Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the DINO Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
|
| 59 |
+
|
| 60 |
+
## 7. Governing Law and Jurisdiction.
|
| 61 |
+
|
| 62 |
+
This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
|
| 63 |
+
|
| 64 |
+
## 8. Modifications and Amendments.
|
| 65 |
+
|
| 66 |
+
Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the DINO Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
|
InfiniDepth/model/block/torchhub/dinov3/MODEL_CARD.md
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Card for DINOv3
|
| 2 |
+
|
| 3 |
+
DINOv3 is a family of versatile vision foundation models that outperforms the specialized state of the art across a broad range of settings, without fine-tuning. DINOv3 produces high-quality dense features that achieve outstanding performance on various vision tasks, significantly surpassing previous self- and weakly-supervised foundation models.
|
| 4 |
+
|
| 5 |
+
## Model Details
|
| 6 |
+
|
| 7 |
+
These are Vision Transformer and ConvNeXt models trained following the method described in the DINOv3 paper. 12 models are provided:
|
| 8 |
+
|
| 9 |
+
- 10 models pretrained on web data (LVD-1689M dataset)
|
| 10 |
+
- 1 ViT-7B trained from scratch,
|
| 11 |
+
- 5 ViT-S/S+/B/L/H+ models distilled from the ViT-7B,
|
| 12 |
+
- 4 ConvNeXt-{T/S/B/L} models distilled from the ViT-7B,
|
| 13 |
+
- 2 models pretrained on satellite data (SAT-493M dataset)
|
| 14 |
+
- 1 ViT-7B trained from scratch
|
| 15 |
+
- 1 ViT-L distilled from the ViT-7B
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
Each Transformer-based model takes an image as input and returns a class token, patch tokens (and register tokens). These models follow a ViT architecture, with a patch size of 16. For a 224x224 image, this results in 1 class token + 4 register tokens + 196 patch tokens = 201 tokens (for DINOv2 with registers this resulted in 1 + 4 + 256 = 261 tokens).
|
| 19 |
+
|
| 20 |
+
The models can accept larger images provided the image shapes are multiples of the patch size (16). If this condition is not verified, the model will crop to the closest smaller multiple of the patch size.
|
| 21 |
+
|
| 22 |
+
### Model Description
|
| 23 |
+
|
| 24 |
+
- **Developed by:** Meta AI
|
| 25 |
+
- **Model type:** Vision Transformer, ConvNeXt
|
| 26 |
+
- **License:** [DINOv3 License](https://ai.meta.com/resources/models-and-libraries/dinov3-license/)
|
| 27 |
+
|
| 28 |
+
### Model Sources
|
| 29 |
+
|
| 30 |
+
- **Repository:** [https://github.com/facebookresearch/dinov3](https://github.com/facebookresearch/dinov3)
|
| 31 |
+
- **Paper:** [https://arxiv.org/abs/2508.10104](https://arxiv.org/abs/2508.10104)
|
| 32 |
+
|
| 33 |
+
## Uses
|
| 34 |
+
|
| 35 |
+
The models are vision backbones providing multi-purpose features for downstream tasks.
|
| 36 |
+
|
| 37 |
+
### Direct Use
|
| 38 |
+
|
| 39 |
+
The models can be used without fine-tuning, with downstream classifiers as simple as linear layers, to obtain competitive results:
|
| 40 |
+
|
| 41 |
+
- on image classification, using k-NN classifiers on the class token
|
| 42 |
+
- on image classification, with logistic regression classifiers applied on the class token
|
| 43 |
+
- on image classification, with a linear layer applied on the class token and the average of the patch tokens
|
| 44 |
+
- on image retrieval using nearest neighbors
|
| 45 |
+
- on geometric and semantic 3D keypoint correspondences
|
| 46 |
+
- on depth estimation, semantic segmentation, using linear layers
|
| 47 |
+
- on unsupervised object discovery
|
| 48 |
+
- on video segmentation tracking
|
| 49 |
+
- on video classification, using a small 4-layer attentive probe
|
| 50 |
+
|
| 51 |
+
### Downstream Use
|
| 52 |
+
|
| 53 |
+
While fine-tuning the models can yield some gains, it is recommended to keep this option as a last resort: the frozen features are expected to provide good performance out-of-the-box.
|
| 54 |
+
|
| 55 |
+
## Bias, Risks, and Limitations
|
| 56 |
+
|
| 57 |
+
Compared to DINOv2 and SEERv2, DINOv3 delivers somewhat consistent performance across income categories on geographical fairness and diversity, although with a notable performance drop in the low-income bucket compared to the highest-income bucket.
|
| 58 |
+
|
| 59 |
+
DINOv3 also achieves relatively good scores across different regions, improving over its predecessor DINOv2. However, a relative difference is still observed between Europe and Africa.
|
| 60 |
+
|
| 61 |
+
### Recommendations
|
| 62 |
+
|
| 63 |
+
Fine-tuning is expected to increase the biases in the features produced by the model as they will be tuned to the fine-tuning labels.
|
| 64 |
+
|
| 65 |
+
## How to Get Started with the Model
|
| 66 |
+
|
| 67 |
+
Use the code below to get started with the model.
|
| 68 |
+
|
| 69 |
+
```python
|
| 70 |
+
import torch
|
| 71 |
+
|
| 72 |
+
model = torch.hub.load(
|
| 73 |
+
repo_or_dir='facebookresearch/dinov3',
|
| 74 |
+
model='<MODEL_NAME>',
|
| 75 |
+
weights='<PATH/OR/URL/TO/CHECKPOINT>',
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# where MODEL_NAME can be one of:
|
| 79 |
+
# - dinov3_vits16
|
| 80 |
+
# - dinov3_vits16plus
|
| 81 |
+
# - dinov3_vitb16
|
| 82 |
+
# - dinov3_vitl16
|
| 83 |
+
# - dinov3_vith16plus
|
| 84 |
+
# - dinov3_vit7b16
|
| 85 |
+
# - dinov3_convnext_tiny
|
| 86 |
+
# - dinov3_convnext_small
|
| 87 |
+
# - dinov3_convnext_base
|
| 88 |
+
# - dinov3_convnext_large
|
| 89 |
+
|
| 90 |
+
# For instance
|
| 91 |
+
dinov3_vits16 = torch.hub.load(
|
| 92 |
+
repo_or_dir='facebookresearch/dinov3',
|
| 93 |
+
model='dinov3_vits16',
|
| 94 |
+
weights='<PATH/OR/URL/TO/DINOV3/VITS16/LVD1689M/CHECKPOINT>',
|
| 95 |
+
)
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
## Training Details
|
| 99 |
+
|
| 100 |
+
### Training Data
|
| 101 |
+
|
| 102 |
+
- Web dataset (LVD-1689M): a curated dataset of 1,689 million images extracted from a large data
|
| 103 |
+
pool of 17 billion web images collected from public posts on Instagram
|
| 104 |
+
|
| 105 |
+
- Satellite dataset (SAT-493M): a dataset of 493 million 512x512 images sampled randomly from Maxar RGB ortho-rectified imagery at 0.6 meter resolution
|
| 106 |
+
|
| 107 |
+
### Training Procedure
|
| 108 |
+
|
| 109 |
+
**Training objective:**
|
| 110 |
+
|
| 111 |
+
- DINO self-distillation loss with multi-crop
|
| 112 |
+
- iBOT masked-image modeling loss
|
| 113 |
+
- KoLeo regularization on [CLS] tokens
|
| 114 |
+
- Gram anchoring
|
| 115 |
+
|
| 116 |
+
- **Training regime:** PyTorch FSDP2 (with bf16 and fp8 matrix multiplications)
|
| 117 |
+
|
| 118 |
+
**Distillation:**
|
| 119 |
+
|
| 120 |
+
- Distillation follows the standard DINOv3 pretraining procedure, except the teacher is a frozen pretrained ViT-7B.
|
| 121 |
+
|
| 122 |
+
## Evaluation
|
| 123 |
+
|
| 124 |
+
**Results**
|
| 125 |
+
|
| 126 |
+
The reader is referred to the associated paper for details on the evaluation protocols.
|
| 127 |
+
|
| 128 |
+
*Results for ViT backbones pretrained (or distilled) on web (LVD-1689M)*
|
| 129 |
+
|
| 130 |
+
<table>
|
| 131 |
+
<tr>
|
| 132 |
+
<th></th>
|
| 133 |
+
<!-- <th></th> -->
|
| 134 |
+
<th colspan="4">Global Tasks</th>
|
| 135 |
+
<th colspan="5">Dense Tasks</th>
|
| 136 |
+
</tr>
|
| 137 |
+
<tr>
|
| 138 |
+
<th>Model</th>
|
| 139 |
+
<!-- <th>Dataset</th> -->
|
| 140 |
+
<th>IN-ReaL</th>
|
| 141 |
+
<th>IN-R</th>
|
| 142 |
+
<th>Obj.Net</th>
|
| 143 |
+
<th>Ox.-H</th>
|
| 144 |
+
<th>ADE20k</th>
|
| 145 |
+
<th>NYU↓</th>
|
| 146 |
+
<th>DAVIS</th>
|
| 147 |
+
<th>NAVI</th>
|
| 148 |
+
<th>SPair</th>
|
| 149 |
+
</tr>
|
| 150 |
+
<tr>
|
| 151 |
+
<td>DINOv3 ViT-S/16</td>
|
| 152 |
+
<!-- <td>LVD-1689M</td> -->
|
| 153 |
+
<td align="right">87.0</td>
|
| 154 |
+
<td align="right">60.4</td>
|
| 155 |
+
<td align="right">50.9</td>
|
| 156 |
+
<td align="right">49.5</td>
|
| 157 |
+
<td align="right">47.0</td>
|
| 158 |
+
<td align="right">0.403</td>
|
| 159 |
+
<td align="right">72.7</td>
|
| 160 |
+
<td align="right">56.3</td>
|
| 161 |
+
<td align="right">50.4</td>
|
| 162 |
+
</tr>
|
| 163 |
+
<tr>
|
| 164 |
+
<td>DINOv3 ViT-S+/16</td>
|
| 165 |
+
<!-- <td>LVD-1689M</td> -->
|
| 166 |
+
<td align="right">88.0</td>
|
| 167 |
+
<td align="right">68.8</td>
|
| 168 |
+
<td align="right">54.6</td>
|
| 169 |
+
<td align="right">50.0</td>
|
| 170 |
+
<td align="right">48.8</td>
|
| 171 |
+
<td align="right">0.399</td>
|
| 172 |
+
<td align="right">75.5</td>
|
| 173 |
+
<td align="right">57.1</td>
|
| 174 |
+
<td align="right">55.2</td>
|
| 175 |
+
</tr>
|
| 176 |
+
<tr>
|
| 177 |
+
<td>DINOv3 ViT-B/16</td>
|
| 178 |
+
<!-- <td>LVD-1689M</td> -->
|
| 179 |
+
<td align="right">89.3</td>
|
| 180 |
+
<td align="right">76.7</td>
|
| 181 |
+
<td align="right">64.1</td>
|
| 182 |
+
<td align="right">58.5</td>
|
| 183 |
+
<td align="right">51.8</td>
|
| 184 |
+
<td align="right">0.373</td>
|
| 185 |
+
<td align="right">77.2</td>
|
| 186 |
+
<td align="right">58.8</td>
|
| 187 |
+
<td align="right">57.2</td>
|
| 188 |
+
</tr>
|
| 189 |
+
<tr>
|
| 190 |
+
<td>DINOv3 ViT-L/16</td>
|
| 191 |
+
<!-- <td>LVD-1689M</td> -->
|
| 192 |
+
<td align="right">90.2</td>
|
| 193 |
+
<td align="right">88.1</td>
|
| 194 |
+
<td align="right">74.8</td>
|
| 195 |
+
<td align="right">63.1</td>
|
| 196 |
+
<td align="right">54.9</td>
|
| 197 |
+
<td align="right">0.352</td>
|
| 198 |
+
<td align="right">79.9</td>
|
| 199 |
+
<td align="right">62.3</td>
|
| 200 |
+
<td align="right">61.3</td>
|
| 201 |
+
</tr>
|
| 202 |
+
<tr>
|
| 203 |
+
<td>DINOv3 ViT-H+/16</td>
|
| 204 |
+
<!-- <td>LVD-1689M</td> -->
|
| 205 |
+
<td align="right">90.3</td>
|
| 206 |
+
<td align="right">90.0</td>
|
| 207 |
+
<td align="right">78.6</td>
|
| 208 |
+
<td align="right">64.5</td>
|
| 209 |
+
<td align="right">54.8</td>
|
| 210 |
+
<td align="right">0.352</td>
|
| 211 |
+
<td align="right">79.3</td>
|
| 212 |
+
<td align="right">63.3</td>
|
| 213 |
+
<td align="right">56.3</td>
|
| 214 |
+
</tr>
|
| 215 |
+
<tr>
|
| 216 |
+
<td>DINOv3 ViT-7B/16</td>
|
| 217 |
+
<!-- <td>LVD-1689M</td> -->
|
| 218 |
+
<td align="right">90.4</td>
|
| 219 |
+
<td align="right">91.1</td>
|
| 220 |
+
<td align="right">91.1</td>
|
| 221 |
+
<td align="right">72.8</td>
|
| 222 |
+
<td align="right">55.9</td>
|
| 223 |
+
<td align="right">0.309</td>
|
| 224 |
+
<td align="right">79.7</td>
|
| 225 |
+
<td align="right">64.4</td>
|
| 226 |
+
<td align="right">58.7</td>
|
| 227 |
+
</tr>
|
| 228 |
+
</table>
|
| 229 |
+
|
| 230 |
+
*Results for ConvNeXt backbones distilled on web (LVD-1689M)*
|
| 231 |
+
|
| 232 |
+
<table>
|
| 233 |
+
<tr>
|
| 234 |
+
<th></th>
|
| 235 |
+
<th colspan="6">Global Tasks</th>
|
| 236 |
+
<th colspan="2">Dense Tasks</th>
|
| 237 |
+
</tr>
|
| 238 |
+
<tr>
|
| 239 |
+
<th>Model</th>
|
| 240 |
+
<th colspan="2">IN-ReaL</th>
|
| 241 |
+
<th colspan="2">IN-R</th>
|
| 242 |
+
<th colspan="2">Obj.Net</th>
|
| 243 |
+
<th>ADE20k</th>
|
| 244 |
+
<th>NYU↓</th>
|
| 245 |
+
</tr>
|
| 246 |
+
<tr>
|
| 247 |
+
    <td></td>
|
| 248 |
+
<td>@256px</td>
|
| 249 |
+
<td>@512px</td>
|
| 250 |
+
<td>@256px</td>
|
| 251 |
+
<td>@512px</td>
|
| 252 |
+
<td>@256px</td>
|
| 253 |
+
<td>@512px</td>
|
| 254 |
+
<td colspan="2"></td>
|
| 255 |
+
</tr>
|
| 256 |
+
<tr>
|
| 257 |
+
<td>DINOv3 ConvNeXt Tiny</td>
|
| 258 |
+
<td align="right">86.6</td>
|
| 259 |
+
<td align="right">87.7</td>
|
| 260 |
+
<td align="right">73.7</td>
|
| 261 |
+
<td align="right">74.1</td>
|
| 262 |
+
<td align="right">52.6</td>
|
| 263 |
+
<td align="right">58.7</td>
|
| 264 |
+
<td align="right">42.7</td>
|
| 265 |
+
<td align="right">0.448</td>
|
| 266 |
+
</tr>
|
| 267 |
+
<tr>
|
| 268 |
+
<td>DINOv3 ConvNeXt Small</td>
|
| 269 |
+
<td align="right">87.9</td>
|
| 270 |
+
<td align="right">88.7</td>
|
| 271 |
+
<td align="right">73.7</td>
|
| 272 |
+
<td align="right">74.1</td>
|
| 273 |
+
<td align="right">52.6</td>
|
| 274 |
+
<td align="right">58.7</td>
|
| 275 |
+
<td align="right">44.8</td>
|
| 276 |
+
<td align="right">0.432</td>
|
| 277 |
+
</tr>
|
| 278 |
+
<tr>
|
| 279 |
+
<td>DINOv3 ConvNeXt Base</td>
|
| 280 |
+
<td align="right">88.5</td>
|
| 281 |
+
<td align="right">89.2</td>
|
| 282 |
+
<td align="right">77.2</td>
|
| 283 |
+
<td align="right">78.2</td>
|
| 284 |
+
<td align="right">56.2</td>
|
| 285 |
+
<td align="right">61.3</td>
|
| 286 |
+
<td align="right">46.3</td>
|
| 287 |
+
<td align="right">0.420</td>
|
| 288 |
+
</tr>
|
| 289 |
+
<tr>
|
| 290 |
+
<td>DINOv3 ConvNeXt Large</td>
|
| 291 |
+
<td align="right">88.9</td>
|
| 292 |
+
<td align="right">89.4</td>
|
| 293 |
+
<td align="right">81.3</td>
|
| 294 |
+
<td align="right">82.4</td>
|
| 295 |
+
<td align="right">59.3</td>
|
| 296 |
+
<td align="right">65.2</td>
|
| 297 |
+
<td align="right">47.8</td>
|
| 298 |
+
<td align="right">0.403</td>
|
| 299 |
+
</tr>
|
| 300 |
+
</table>
|
| 301 |
+
|
| 302 |
+
*Results for ViT backbones pretrained (or distilled) on satellite (SAT-493M)*
|
| 303 |
+
|
| 304 |
+
<table>
|
| 305 |
+
<tr>
|
| 306 |
+
<th></th>
|
| 307 |
+
<th colspan="7">(GEO-Bench) Classification</th>
|
| 308 |
+
</tr>
|
| 309 |
+
<tr>
|
| 310 |
+
    <th>Model</th>
|
| 311 |
+
<th>m-BEnet</th>
|
| 312 |
+
    <th>m-brick-kiln</th>
|
| 313 |
+
<th>m-eurosat</th>
|
| 314 |
+
<th>m-forestnet</th>
|
| 315 |
+
<th>m-pv4ger</th>
|
| 316 |
+
<th>m-so2sat</th>
|
| 317 |
+
<th>mean</th>
|
| 318 |
+
</tr>
|
| 319 |
+
<tr>
|
| 320 |
+
<td>DINOv3 ViT-L/16</td>
|
| 321 |
+
<td>73.0</td>
|
| 322 |
+
<td>96.5</td>
|
| 323 |
+
<td>94.1</td>
|
| 324 |
+
<td>60.6</td>
|
| 325 |
+
<td>96.0</td>
|
| 326 |
+
<td>57.4</td>
|
| 327 |
+
<td>79.6</td>
|
| 328 |
+
</tr>
|
| 329 |
+
<tr>
|
| 330 |
+
<td>DINOv3 ViT-7B/16</td>
|
| 331 |
+
<td>74.0</td>
|
| 332 |
+
<td>97.2</td>
|
| 333 |
+
<td>94.8</td>
|
| 334 |
+
<td>62.3</td>
|
| 335 |
+
<td>96.1</td>
|
| 336 |
+
<td>62.1</td>
|
| 337 |
+
<td>81.1</td>
|
| 338 |
+
</tr>
|
| 339 |
+
<tr>
|
| 340 |
+
<th></th>
|
| 341 |
+
<th colspan="7">(GEO-Bench) Segmentation</th>
|
| 342 |
+
</tr>
|
| 343 |
+
<tr>
|
| 344 |
+
<th>Model</th>
|
| 345 |
+
<th>m-cashew</th>
|
| 346 |
+
<th>m-chesapeake</th>
|
| 347 |
+
<th>m-NeonTree</th>
|
| 348 |
+
<th>m-nz-cattle</th>
|
| 349 |
+
<th>m-pv4ger-seg</th>
|
| 350 |
+
<th>m-SA-crop</th>
|
| 351 |
+
<th>mean</th>
|
| 352 |
+
</tr>
|
| 353 |
+
<tr>
|
| 354 |
+
<td>DINOv3 ViT-L/16</td>
|
| 355 |
+
<td>94.2</td>
|
| 356 |
+
<td>75.6</td>
|
| 357 |
+
<td>61.8</td>
|
| 358 |
+
<td>83.7</td>
|
| 359 |
+
<td>95.2</td>
|
| 360 |
+
<td>36.8</td>
|
| 361 |
+
<td>74.5</td>
|
| 362 |
+
</tr>
|
| 363 |
+
<tr>
|
| 364 |
+
<td>DINOv3 ViT-7B/16</td>
|
| 365 |
+
<td>94.1</td>
|
| 366 |
+
<td>76.6</td>
|
| 367 |
+
<td>62.6</td>
|
| 368 |
+
<td>83.4</td>
|
| 369 |
+
<td>95.5</td>
|
| 370 |
+
<td>37.6</td>
|
| 371 |
+
<td>75.0</td>
|
| 372 |
+
</tr>
|
| 373 |
+
</table>
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
## Environmental Impact
|
| 377 |
+
|
| 378 |
+
- **Hardware Type:** Nvidia H100
|
| 379 |
+
- **Hours used:** 61,440 hours for ViT-7B model training
|
| 380 |
+
- **Cloud Provider:** Private infrastructure
|
| 381 |
+
- **Compute Region:** USA
|
| 382 |
+
- **Carbon Emitted:** 18t CO2eq
|
| 383 |
+
|
| 384 |
+
## Technical Specifications
|
| 385 |
+
|
| 386 |
+
### Model Architecture and Objective
|
| 387 |
+
|
| 388 |
+
Vision Transformer models:
|
| 389 |
+
|
| 390 |
+
- ViT-S (21M parameters): patch size 16, embedding dimension 384, 4 register tokens, 6 heads, MLP FFN, RoPE
|
| 391 |
+
- ViT-S+ (29M parameters): patch size 16, embedding dimension 384, 4 register tokens, 6 heads, SwiGLU FFN, RoPE
|
| 392 |
+
- ViT-B (86M parameters): patch size 16, embedding dimension 768, 4 register tokens, 12 heads, MLP FFN, RoPE
|
| 393 |
+
- ViT-L (300M parameters): patch size 16, embedding dimension 1024, 4 register tokens, 16 heads, MLP FFN, RoPE
|
| 394 |
+
- ViT-H+ (840M parameters): patch size 16, embedding dimension 1280, 4 register tokens, 20 heads, SwiGLU FFN, RoPE
|
| 395 |
+
- ViT-7B (6716M parameters): patch size 16, embedding dimension 4096, 4 register tokens, 32 heads, SwiGLU FFN, RoPE
|
| 396 |
+
|
| 397 |
+
ConvNeXt models:
|
| 398 |
+
|
| 399 |
+
- ConvNeXt Tiny (29M parameters)
|
| 400 |
+
- ConvNeXt Small (50M parameters)
|
| 401 |
+
- ConvNeXt Base (89M parameters)
|
| 402 |
+
- ConvNeXt Large (198M parameters)
|
| 403 |
+
|
| 404 |
+
### Compute Infrastructure
|
| 405 |
+
|
| 406 |
+
#### Hardware
|
| 407 |
+
|
| 408 |
+
Nvidia H100 GPUs
|
| 409 |
+
|
| 410 |
+
#### Software
|
| 411 |
+
|
| 412 |
+
PyTorch 2.7
|
| 413 |
+
|
| 414 |
+
## More Information
|
| 415 |
+
|
| 416 |
+
See the [blog post](https://ai.meta.com/blog/dinov3-self-supervised-vision-model/) and the associated [website](https://ai.meta.com/dinov3/).
|
| 417 |
+
|
| 418 |
+
## Citation
|
| 419 |
+
|
| 420 |
+
**BibTeX**
|
| 421 |
+
|
| 422 |
+
```
|
| 423 |
+
@misc{simeoni2025dinov3,
|
| 424 |
+
title={{DINOv3}},
|
| 425 |
+
author={Sim{\'e}oni, Oriane and Vo, Huy V. and Seitzer, Maximilian and Baldassarre, Federico and Oquab, Maxime and Jose, Cijo and Khalidov, Vasil and Szafraniec, Marc and Yi, Seungeun and Ramamonjisoa, Micha{\"e}l and Massa, Francisco and Haziza, Daniel and Wehrstedt, Luca and Wang, Jianyuan and Darcet, Timoth{\'e}e and Moutakanni, Th{\'e}o and Sentana, Leonel and Roberts, Claire and Vedaldi, Andrea and Tolan, Jamie and Brandt, John and Couprie, Camille and Mairal, Julien and J{\'e}gou, Herv{\'e} and Labatut, Patrick and Bojanowski, Piotr},
|
| 426 |
+
year={2025},
|
| 427 |
+
eprint={2508.10104},
|
| 428 |
+
archivePrefix={arXiv},
|
| 429 |
+
primaryClass={cs.CV},
|
| 430 |
+
url={https://arxiv.org/abs/2508.10104},
|
| 431 |
+
}
|
| 432 |
+
```
|
InfiniDepth/model/block/torchhub/dinov3/README.md
ADDED
|
@@ -0,0 +1,734 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
🆕 [2025-08-14] :fire: DINOv3 backbones are now available in [Hugging Face Hub](https://huggingface.co/collections/facebook/dinov3-68924841bd6b561778e31009) and [supported](https://huggingface.co/docs/transformers/model_doc/dinov3) by the Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) library
|
| 2 |
+
|
| 3 |
+
# DINOv3 🦖🦖🦖
|
| 4 |
+
|
| 5 |
+
**[Meta AI Research, FAIR](https://ai.meta.com/research/)**
|
| 6 |
+
|
| 7 |
+
Oriane Siméoni, Huy V. Vo, Maximilian Seitzer, Federico Baldassarre, Maxime Oquab, <br/>
|
| 8 |
+
Cijo Jose, Vasil Khalidov, Marc Szafraniec, Seungeun Yi, Michaël Ramamonjisoa, <br/>
|
| 9 |
+
Francisco Massa, Daniel Haziza, Luca Wehrstedt, Jianyuan Wang, <br/>
|
| 10 |
+
Timothée Darcet, Théo Moutakanni, Leonel Sentana, Claire Roberts, <br/>
|
| 11 |
+
Andrea Vedaldi, Jamie Tolan, John Brandt, Camille Couprie, <br/>
|
| 12 |
+
Julien Mairal, Hervé Jégou, Patrick Labatut, Piotr Bojanowski
|
| 13 |
+
|
| 14 |
+
[ :scroll: [`Paper`](https://arxiv.org/abs/2508.10104)] [ :newspaper: [`Blog`](https://ai.meta.com/blog/dinov3-self-supervised-vision-model/)] [ :globe_with_meridians: [`Website`](https://ai.meta.com/dinov3/)] [ :book: [`BibTeX`](#citing-dinov3)]
|
| 15 |
+
|
| 16 |
+
Reference PyTorch implementation and models for DINOv3. For details, see the **[DINOv3](https://arxiv.org/abs/2508.10104)** paper.
|
| 17 |
+
|
| 18 |
+
## Overview
|
| 19 |
+
|
| 20 |
+
<div align="center">
|
| 21 |
+
<img width="1364" height="1024" alt="market" src="https://github.com/user-attachments/assets/1411f491-988e-49cb-95ae-d03fe6e3c268" />
|
| 22 |
+
|
| 23 |
+
<i><b>High-resolution dense features.</b><br/>We visualize the cosine similarity maps obtained with DINOv3 output features<br/> between the patches marked with a red cross and all other patches.</i>
|
| 24 |
+
</div>
|
| 25 |
+
|
| 26 |
+
<br/>
|
| 27 |
+
|
| 28 |
+
An extended family of versatile vision foundation models producing high-quality dense features and achieving outstanding performance on various vision tasks including outperforming the specialized state of the art across a broad range of settings, without fine-tuning
|
| 29 |
+
|
| 30 |
+
## Pretrained models
|
| 31 |
+
|
| 32 |
+
:information_source: Please follow the link provided below to get access to all the model weights: once accepted, an e-mail will be sent with the complete list of URLs pointing to all the available model weights (both backbones and adapters). These URLs can then be used to either:
|
| 33 |
+
- download the model or adapter weights to a local filesystem and point `torch.hub.load()` to these local weights via the `weights` or `backbone_weights` parameters, or
|
| 34 |
+
- directly invoke `torch.hub.load()` to download and load a backbone or an adapter from its URL via also the `weights` or `backbone_weights` parameters.
|
| 35 |
+
|
| 36 |
+
See the example code snippets below.
|
| 37 |
+
|
| 38 |
+
:warning: Please use `wget` instead of a web browser to download the weights.
|
| 39 |
+
|
| 40 |
+
ViT models pretrained on web dataset (LVD-1689M):
|
| 41 |
+
<table style="margin: auto">
|
| 42 |
+
<thead>
|
| 43 |
+
<tr>
|
| 44 |
+
<th>Model</th>
|
| 45 |
+
<th>Parameters</th>
|
| 46 |
+
<th>Pretraining<br/>Dataset</th>
|
| 47 |
+
<th>Download</th>
|
| 48 |
+
</tr>
|
| 49 |
+
</thead>
|
| 50 |
+
<tbody>
|
| 51 |
+
<tr>
|
| 52 |
+
<td>ViT-S/16 distilled </td>
|
| 53 |
+
<td align="right">21M</td>
|
| 54 |
+
<td align="center">LVD-1689M</td>
|
| 55 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 56 |
+
</tr>
|
| 57 |
+
<tr>
|
| 58 |
+
<td>ViT-S+/16 distilled</td>
|
| 59 |
+
<td align="right">29M</td>
|
| 60 |
+
<td align="center">LVD-1689M</td>
|
| 61 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 62 |
+
</tr>
|
| 63 |
+
<tr>
|
| 64 |
+
<td>ViT-B/16 distilled</td>
|
| 65 |
+
<td align="right">86M</td>
|
| 66 |
+
<td align="center">LVD-1689M</td>
|
| 67 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 68 |
+
</tr>
|
| 69 |
+
<tr>
|
| 70 |
+
<td>ViT-L/16 distilled</td>
|
| 71 |
+
<td align="right">300M</td>
|
| 72 |
+
<td align="center">LVD-1689M</td>
|
| 73 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 74 |
+
</tr>
|
| 75 |
+
<tr>
|
| 76 |
+
<td>ViT-H+/16 distilled</td>
|
| 77 |
+
<td align="right">840M</td>
|
| 78 |
+
<td align="center">LVD-1689M</td>
|
| 79 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 80 |
+
</tr>
|
| 81 |
+
<tr>
|
| 82 |
+
<td>ViT-7B/16</td>
|
| 83 |
+
<td align="right">6,716M</td>
|
| 84 |
+
<td align="center">LVD-1689M</td>
|
| 85 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 86 |
+
</tr>
|
| 87 |
+
</tbody>
|
| 88 |
+
</table>
|
| 89 |
+
|
| 90 |
+
ConvNeXt models pretrained on web dataset (LVD-1689M):
|
| 91 |
+
<table style="margin: auto">
|
| 92 |
+
<thead>
|
| 93 |
+
<tr>
|
| 94 |
+
<th>Model</th>
|
| 95 |
+
<th>Parameters</th>
|
| 96 |
+
<th>Pretraining<br/>Dataset</th>
|
| 97 |
+
<th>Download</th>
|
| 98 |
+
</tr>
|
| 99 |
+
</thead>
|
| 100 |
+
<tbody>
|
| 101 |
+
<tr>
|
| 102 |
+
<td>ConvNeXt Tiny</td>
|
| 103 |
+
<td align="right">29M</td>
|
| 104 |
+
<td align="center">LVD-1689M</td>
|
| 105 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 106 |
+
</tr>
|
| 107 |
+
<tr>
|
| 108 |
+
<td>ConvNeXt Small</td>
|
| 109 |
+
<td align="right">50M</td>
|
| 110 |
+
<td align="center">LVD-1689M</td>
|
| 111 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 112 |
+
</tr>
|
| 113 |
+
<tr>
|
| 114 |
+
<td>ConvNeXt Base</td>
|
| 115 |
+
<td align="right">89M</td>
|
| 116 |
+
<td align="center">LVD-1689M</td>
|
| 117 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 118 |
+
</tr>
|
| 119 |
+
<tr>
|
| 120 |
+
<td>ConvNeXt Large</td>
|
| 121 |
+
<td align="right">198M</td>
|
| 122 |
+
<td align="center">LVD-1689M</td>
|
| 123 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 124 |
+
</tr>
|
| 125 |
+
</tbody>
|
| 126 |
+
</table>
|
| 127 |
+
|
| 128 |
+
ViT models pretrained on satellite dataset (SAT-493M):
|
| 129 |
+
<table style="margin: auto">
|
| 130 |
+
<thead>
|
| 131 |
+
<tr>
|
| 132 |
+
<th>Model</th>
|
| 133 |
+
<th>Parameters</th>
|
| 134 |
+
<th>Pretraining<br/>Dataset</th>
|
| 135 |
+
<th>Download</th>
|
| 136 |
+
</tr>
|
| 137 |
+
</thead>
|
| 138 |
+
<tbody>
|
| 139 |
+
<tr>
|
| 140 |
+
<td>ViT-L/16 distilled</td>
|
| 141 |
+
<td align="right">300M</td>
|
| 142 |
+
<td align="center">SAT-493M</td>
|
| 143 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 144 |
+
</tr>
|
| 145 |
+
<tr>
|
| 146 |
+
<td>ViT-7B/16</td>
|
| 147 |
+
<td align="right">6,716M</td>
|
| 148 |
+
<td align="center">SAT-493M</td>
|
| 149 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 150 |
+
</tr>
|
| 151 |
+
</tbody>
|
| 152 |
+
</table>
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
### Pretrained backbones (via PyTorch [Hub](https://docs.pytorch.org/docs/stable/hub.html))
|
| 156 |
+
|
| 157 |
+
Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install PyTorch (the only required dependency for loading the model). Installing PyTorch with CUDA support is strongly recommended.
|
| 158 |
+
|
| 159 |
+
```python
|
| 160 |
+
import torch
|
| 161 |
+
|
| 162 |
+
REPO_DIR = <PATH/TO/A/LOCAL/DIRECTORY/WHERE/THE/DINOV3/REPO/WAS/CLONED>
|
| 163 |
+
|
| 164 |
+
# DINOv3 ViT models pretrained on web images
|
| 165 |
+
dinov3_vits16 = torch.hub.load(REPO_DIR, 'dinov3_vits16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
|
| 166 |
+
dinov3_vits16plus = torch.hub.load(REPO_DIR, 'dinov3_vits16plus', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
|
| 167 |
+
dinov3_vitb16 = torch.hub.load(REPO_DIR, 'dinov3_vitb16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
|
| 168 |
+
dinov3_vitl16 = torch.hub.load(REPO_DIR, 'dinov3_vitl16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
|
| 169 |
+
dinov3_vith16plus = torch.hub.load(REPO_DIR, 'dinov3_vith16plus', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
|
| 170 |
+
dinov3_vit7b16 = torch.hub.load(REPO_DIR, 'dinov3_vit7b16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
|
| 171 |
+
|
| 172 |
+
# DINOv3 ConvNeXt models pretrained on web images
|
| 173 |
+
dinov3_convnext_tiny = torch.hub.load(REPO_DIR, 'dinov3_convnext_tiny', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
|
| 174 |
+
dinov3_convnext_small = torch.hub.load(REPO_DIR, 'dinov3_convnext_small', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
|
| 175 |
+
dinov3_convnext_base = torch.hub.load(REPO_DIR, 'dinov3_convnext_base', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
|
| 176 |
+
dinov3_convnext_large = torch.hub.load(REPO_DIR, 'dinov3_convnext_large', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
|
| 177 |
+
|
| 178 |
+
# DINOv3 ViT models pretrained on satellite imagery
|
| 179 |
+
dinov3_vitl16 = torch.hub.load(REPO_DIR, 'dinov3_vitl16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
|
| 180 |
+
dinov3_vit7b16 = torch.hub.load(REPO_DIR, 'dinov3_vit7b16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
### Pretrained backbones (via Hugging Face [Transformers](https://huggingface.co/docs/transformers/))
|
| 184 |
+
|
| 185 |
+
All the backbones are available in the [DINOv3](https://huggingface.co/collections/facebook/dinov3-68924841bd6b561778e31009) collection on Hugging Face Hub and supported via the Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) library. Please refer to the corresponding documentation for usage, but below is a short example that demonstrates how to obtain an image embedding with either [Pipeline] or the [AutoModel] class.
|
| 186 |
+
|
| 187 |
+
```python
|
| 188 |
+
from transformers import pipeline
|
| 189 |
+
from transformers.image_utils import load_image
|
| 190 |
+
|
| 191 |
+
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
| 192 |
+
image = load_image(url)
|
| 193 |
+
|
| 194 |
+
feature_extractor = pipeline(
|
| 195 |
+
model="facebook/dinov3-convnext-tiny-pretrain-lvd1689m",
|
| 196 |
+
task="image-feature-extraction",
|
| 197 |
+
)
|
| 198 |
+
features = feature_extractor(image)
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
```python
|
| 202 |
+
import torch
|
| 203 |
+
from transformers import AutoImageProcessor, AutoModel
|
| 204 |
+
from transformers.image_utils import load_image
|
| 205 |
+
|
| 206 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 207 |
+
image = load_image(url)
|
| 208 |
+
|
| 209 |
+
pretrained_model_name = "facebook/dinov3-convnext-tiny-pretrain-lvd1689m"
|
| 210 |
+
processor = AutoImageProcessor.from_pretrained(pretrained_model_name)
|
| 211 |
+
model = AutoModel.from_pretrained(
|
| 212 |
+
pretrained_model_name,
|
| 213 |
+
device_map="auto",
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
inputs = processor(images=image, return_tensors="pt").to(model.device)
|
| 217 |
+
with torch.inference_mode():
|
| 218 |
+
outputs = model(**inputs)
|
| 219 |
+
|
| 220 |
+
pooled_output = outputs.pooler_output
|
| 221 |
+
print("Pooled output shape:", pooled_output.shape)
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
where `model` and `pretrained_model_name` above can be one of:
|
| 225 |
+
- `facebook/dinov3-vits16-pretrain-lvd1689m`
|
| 226 |
+
- `facebook/dinov3-vits16plus-pretrain-lvd1689m`
|
| 227 |
+
- `facebook/dinov3-vitb16-pretrain-lvd1689m`
|
| 228 |
+
- `facebook/dinov3-vitl16-pretrain-lvd1689m`
|
| 229 |
+
- `facebook/dinov3-vith16plus-pretrain-lvd1689m`
|
| 230 |
+
- `facebook/dinov3-vit7b16-pretrain-lvd1689m`
|
| 231 |
+
- `facebook/dinov3-convnext-base-pretrain-lvd1689m`
|
| 232 |
+
- `facebook/dinov3-convnext-large-pretrain-lvd1689m`
|
| 233 |
+
- `facebook/dinov3-convnext-small-pretrain-lvd1689m`
|
| 234 |
+
- `facebook/dinov3-convnext-tiny-pretrain-lvd1689m`
|
| 235 |
+
- `facebook/dinov3-vitl16-pretrain-sat493m`
|
| 236 |
+
- `facebook/dinov3-vit7b16-pretrain-sat493m`
|
| 237 |
+
|
| 238 |
+
### Image transforms
|
| 239 |
+
|
| 240 |
+
For models using the LVD-1689M weights (pretrained on web images), please use the following transform (standard ImageNet evaluation transform):
|
| 241 |
+
|
| 242 |
+
```python
|
| 243 |
+
from torchvision import transforms
|
| 244 |
+
|
| 245 |
+
def make_transform(resize_size: int = 224):
|
| 246 |
+
to_tensor = transforms.ToTensor()
|
| 247 |
+
resize = transforms.Resize((resize_size, resize_size), antialias=True)
|
| 248 |
+
normalize = transforms.Normalize(
|
| 249 |
+
mean=(0.485, 0.456, 0.406),
|
| 250 |
+
std=(0.229, 0.224, 0.225),
|
| 251 |
+
)
|
| 252 |
+
return transforms.Compose([to_tensor, resize, normalize])
|
| 253 |
+
```
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
For models using the SAT-493M weights (pretrained on satellite imagery), please use the following transform:
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
```python
|
| 260 |
+
from torchvision import transforms
|
| 261 |
+
|
| 262 |
+
def make_transform(resize_size: int = 224):
|
| 263 |
+
to_tensor = transforms.ToTensor()
|
| 264 |
+
resize = transforms.Resize((resize_size, resize_size), antialias=True)
|
| 265 |
+
normalize = transforms.Normalize(
|
| 266 |
+
mean=(0.430, 0.411, 0.296),
|
| 267 |
+
std=(0.213, 0.156, 0.143),
|
| 268 |
+
)
|
| 269 |
+
return transforms.Compose([to_tensor, resize, normalize])
|
| 270 |
+
```
|
| 271 |
+
|
| 272 |
+
### Pretrained heads - Image classification
|
| 273 |
+
|
| 274 |
+
<table style="margin: auto">
|
| 275 |
+
<thead>
|
| 276 |
+
<tr>
|
| 277 |
+
<th>Backbone</th>
|
| 278 |
+
<th>Pretraining<br/>Dataset</th>
|
| 279 |
+
<th>Head<br/>Dataset</th>
|
| 280 |
+
<th>Download</th>
|
| 281 |
+
</tr>
|
| 282 |
+
</thead>
|
| 283 |
+
<tbody>
|
| 284 |
+
<tr>
|
| 285 |
+
<td>ViT-7B/16</td>
|
| 286 |
+
<td align="center">LVD-1689M</td>
|
| 287 |
+
<td align="center">ImageNet</td>
|
| 288 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 289 |
+
</tr>
|
| 290 |
+
</tbody>
|
| 291 |
+
</table>
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
The (full) classifier models can be loaded via PyTorch Hub:
|
| 295 |
+
|
| 296 |
+
```python
|
| 297 |
+
import torch
|
| 298 |
+
|
| 299 |
+
# DINOv3
|
| 300 |
+
dinov3_vit7b16_lc = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_lc', source="local", weights=<CLASSIFIER/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
|
| 301 |
+
|
| 302 |
+
```
|
| 303 |
+
|
| 304 |
+
### Pretrained heads - Depther trained on SYNTHMIX dataset
|
| 305 |
+
|
| 306 |
+
<table style="margin: auto">
|
| 307 |
+
<thead>
|
| 308 |
+
<tr>
|
| 309 |
+
<th>Backbone</th>
|
| 310 |
+
<th>Pretraining<br/>Dataset</th>
|
| 311 |
+
<th>Head<br/>Dataset</th>
|
| 312 |
+
<th>Download</th>
|
| 313 |
+
</tr>
|
| 314 |
+
</thead>
|
| 315 |
+
<tbody>
|
| 316 |
+
<tr>
|
| 317 |
+
<td>ViT-7B/16</td>
|
| 318 |
+
<td align="center">LVD-1689M</td>
|
| 319 |
+
<td align="center">SYNTHMIX</td>
|
| 320 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 321 |
+
</tr>
|
| 322 |
+
</tbody>
|
| 323 |
+
</table>
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
```python
|
| 327 |
+
depther = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_dd', source="local", weights=<DEPTHER/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
|
| 328 |
+
```
|
| 329 |
+
|
| 330 |
+
Full example code of depther on an image
|
| 331 |
+
|
| 332 |
+
```python
|
| 333 |
+
from PIL import Image
|
| 334 |
+
import torch
|
| 335 |
+
from torchvision import transforms
|
| 336 |
+
import matplotlib.pyplot as plt
|
| 337 |
+
from matplotlib import colormaps
|
| 338 |
+
|
| 339 |
+
def get_img():
|
| 340 |
+
import requests
|
| 341 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 342 |
+
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
| 343 |
+
return image
|
| 344 |
+
|
| 345 |
+
def make_transform(resize_size: int | list[int] = 768):
|
| 346 |
+
to_tensor = transforms.ToTensor()
|
| 347 |
+
resize = transforms.Resize((resize_size, resize_size), antialias=True)
|
| 348 |
+
normalize = transforms.Normalize(
|
| 349 |
+
mean=(0.485, 0.456, 0.406),
|
| 350 |
+
std=(0.229, 0.224, 0.225),
|
| 351 |
+
)
|
| 352 |
+
return transforms.Compose([to_tensor, resize, normalize])
|
| 353 |
+
|
| 354 |
+
depther = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_dd', source="local", weights=<DEPTHER/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
|
| 355 |
+
|
| 356 |
+
img_size = 1024
|
| 357 |
+
img = get_img()
|
| 358 |
+
transform = make_transform(img_size)
|
| 359 |
+
with torch.inference_mode():
|
| 360 |
+
with torch.autocast('cuda', dtype=torch.bfloat16):
|
| 361 |
+
batch_img = transform(img)[None]
|
| 362 |
+
batch_img = batch_img
|
| 363 |
+
depths = depther(batch_img)
|
| 364 |
+
|
| 365 |
+
plt.figure(figsize=(12, 6))
|
| 366 |
+
plt.subplot(121)
|
| 367 |
+
plt.imshow(img)
|
| 368 |
+
plt.axis("off")
|
| 369 |
+
plt.subplot(122)
|
| 370 |
+
plt.imshow(depths[0,0].cpu(), cmap=colormaps["Spectral"])
|
| 371 |
+
plt.axis("off")
|
| 372 |
+
|
| 373 |
+
```
|
| 374 |
+
|
| 375 |
+
### Pretrained heads - Detector trained on COCO2017 dataset
|
| 376 |
+
|
| 377 |
+
<table style="margin: auto">
|
| 378 |
+
<thead>
|
| 379 |
+
<tr>
|
| 380 |
+
<th>Backbone</th>
|
| 381 |
+
<th>Pretraining<br/>Dataset</th>
|
| 382 |
+
<th>Head<br/>Dataset</th>
|
| 383 |
+
<th>Download</th>
|
| 384 |
+
</tr>
|
| 385 |
+
</thead>
|
| 386 |
+
<tbody>
|
| 387 |
+
<tr>
|
| 388 |
+
<td>ViT-7B/16</td>
|
| 389 |
+
<td align="center">LVD-1689M</td>
|
| 390 |
+
<td align="center">COCO2017</td>
|
| 391 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 392 |
+
</tr>
|
| 393 |
+
</tbody>
|
| 394 |
+
</table>
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
```python
|
| 398 |
+
detector = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_de', source="local", weights=<DETECTOR/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
|
| 399 |
+
```
|
| 400 |
+
|
| 401 |
+
### Pretrained heads - Segmentor trained on ADE20K dataset
|
| 402 |
+
|
| 403 |
+
<table style="margin: auto">
|
| 404 |
+
<thead>
|
| 405 |
+
<tr>
|
| 406 |
+
<th>Backbone</th>
|
| 407 |
+
<th>Pretraining<br/>Dataset</th>
|
| 408 |
+
<th>Head<br/>Dataset</th>
|
| 409 |
+
<th>Download</th>
|
| 410 |
+
</tr>
|
| 411 |
+
</thead>
|
| 412 |
+
<tbody>
|
| 413 |
+
<tr>
|
| 414 |
+
<td>ViT-7B/16</td>
|
| 415 |
+
<td align="center">LVD-1689M</td>
|
| 416 |
+
<td align="center">ADE20K</td>
|
| 417 |
+
<td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
|
| 418 |
+
</tr>
|
| 419 |
+
</tbody>
|
| 420 |
+
</table>
|
| 421 |
+
|
| 422 |
+
```python
|
| 423 |
+
segmentor = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_ms', source="local", weights=<SEGMENTOR/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
|
| 424 |
+
```
|
| 425 |
+
|
| 426 |
+
Full example code of the segmentor on an image
|
| 427 |
+
|
| 428 |
+
```python
|
| 429 |
+
import sys
|
| 430 |
+
sys.path.append(REPO_DIR)
|
| 431 |
+
|
| 432 |
+
from PIL import Image
|
| 433 |
+
import torch
|
| 434 |
+
from torchvision import transforms
|
| 435 |
+
import matplotlib.pyplot as plt
|
| 436 |
+
from matplotlib import colormaps
|
| 437 |
+
from functools import partial
|
| 438 |
+
from dinov3.eval.segmentation.inference import make_inference
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def get_img():
|
| 442 |
+
import requests
|
| 443 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 444 |
+
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
| 445 |
+
return image
|
| 446 |
+
|
| 447 |
+
def make_transform(resize_size: int | list[int] = 768):
|
| 448 |
+
to_tensor = transforms.ToTensor()
|
| 449 |
+
resize = transforms.Resize((resize_size, resize_size), antialias=True)
|
| 450 |
+
normalize = transforms.Normalize(
|
| 451 |
+
mean=(0.485, 0.456, 0.406),
|
| 452 |
+
std=(0.229, 0.224, 0.225),
|
| 453 |
+
)
|
| 454 |
+
return transforms.Compose([to_tensor, resize, normalize])
|
| 455 |
+
|
| 456 |
+
segmentor = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_ms', source="local", weights=<SEGMENTOR/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
|
| 457 |
+
|
| 458 |
+
img_size = 896
|
| 459 |
+
img = get_img()
|
| 460 |
+
transform = make_transform(img_size)
|
| 461 |
+
with torch.inference_mode():
|
| 462 |
+
with torch.autocast('cuda', dtype=torch.bfloat16):
|
| 463 |
+
batch_img = transform(img)[None]
|
| 464 |
+
pred_vit7b = segmentor(batch_img) # raw predictions
|
| 465 |
+
# actual segmentation map
|
| 466 |
+
segmentation_map_vit7b = make_inference(
|
| 467 |
+
batch_img,
|
| 468 |
+
segmentor,
|
| 469 |
+
inference_mode="slide",
|
| 470 |
+
decoder_head_type="m2f",
|
| 471 |
+
rescale_to=(img.size[-1], img.size[-2]),
|
| 472 |
+
n_output_channels=150,
|
| 473 |
+
crop_size=(img_size, img_size),
|
| 474 |
+
stride=(img_size, img_size),
|
| 475 |
+
output_activation=partial(torch.nn.functional.softmax, dim=1),
|
| 476 |
+
).argmax(dim=1, keepdim=True)
|
| 477 |
+
plt.figure(figsize=(12, 6))
|
| 478 |
+
plt.subplot(121)
|
| 479 |
+
plt.imshow(img)
|
| 480 |
+
plt.axis("off")
|
| 481 |
+
plt.subplot(122)
|
| 482 |
+
plt.imshow(segmentation_map_vit7b[0,0].cpu(), cmap=colormaps["Spectral"])
|
| 483 |
+
plt.axis("off")
|
| 484 |
+
```
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
### Pretrained heads - Zero-shot tasks with `dino.txt`
|
| 490 |
+
|
| 491 |
+
<table style="margin: auto">
|
| 492 |
+
<thead>
|
| 493 |
+
<tr>
|
| 494 |
+
<th rowspan="2">Backbone</th>
|
| 495 |
+
<th>Download</th>
|
| 496 |
+
</tr>
|
| 497 |
+
</thead>
|
| 498 |
+
<tbody>
|
| 499 |
+
<tr>
|
| 500 |
+
<td>ViT-L/16 distilled</td>
|
| 501 |
+
<td align="center">
|
| 502 |
+
<a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a>,
|
| 503 |
+
<a href="https://dl.fbaipublicfiles.com/dinov3/thirdparty/bpe_simple_vocab_16e6.txt.gz">vocabulary</a>,
|
| 504 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/thirdparty/LICENSE">vocabulary license</a>
|
| 505 |
+
</td>
|
| 506 |
+
</tr>
|
| 507 |
+
</tbody>
|
| 508 |
+
</table>
|
| 509 |
+
|
| 510 |
+
The (full) dino.txt model can be loaded via PyTorch Hub:
|
| 511 |
+
|
| 512 |
+
```python
|
| 513 |
+
import torch
|
| 514 |
+
# DINOv3
|
| 515 |
+
dinov3_vitl16_dinotxt_tet1280d20h24l, tokenizer = torch.hub.load(REPO_DIR, 'dinov3_vitl16_dinotxt_tet1280d20h24l', weights=<DINOTXT/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
|
| 516 |
+
```
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
## Installation
|
| 520 |
+
|
| 521 |
+
The training and evaluation code requires PyTorch version >= 2.7.1 as well as a few other 3rd party packages. Note that the code has only been tested with the specified versions and also expects a Linux environment. To setup all the required dependencies for training and evaluation, please follow the instructions below:
|
| 522 |
+
|
| 523 |
+
*[micromamba](https://mamba.readthedocs.io/en/latest/user_guide/micromamba.html)* **(Recommended)** - Clone the repository and then create and activate a `dinov3` conda environment using the provided environment definition:
|
| 524 |
+
|
| 525 |
+
```shell
|
| 526 |
+
micromamba env create -f conda.yaml
|
| 527 |
+
micromamba activate dinov3
|
| 528 |
+
```
|
| 529 |
+
|
| 530 |
+
## Getting started
|
| 531 |
+
|
| 532 |
+
Several notebooks are provided to get started applying DINOv3:
|
| 533 |
+
- [PCA of patch features](notebooks/pca.ipynb): display the PCA of DINOv3 patch features on a foreground object (rainbow visualizations from the paper) [[Run in Google Colab]](https://colab.research.google.com/github/facebookresearch/dinov3/blob/main/notebooks/pca.ipynb)
|
| 534 |
+
- [Foreground segmentation](notebooks/foreground_segmentation.ipynb): train a linear foreground segmentation model based on DINOv3 features [[Run in Google Colab]](https://colab.research.google.com/github/facebookresearch/dinov3/blob/main/notebooks/foreground_segmentation.ipynb)
|
| 535 |
+
- [Dense and sparse matching](notebooks/dense_sparse_matching.ipynb): match patches from objects on two different images based on DINOv3 features [[Run in Google Colab]](https://colab.research.google.com/github/facebookresearch/dinov3/blob/main/notebooks/dense_sparse_matching.ipynb)
|
| 536 |
+
- [Segmentation tracking](notebooks/segmentation_tracking.ipynb): video segmentation tracking using a non-parametric method based on DINOv3 features [[Run in Google Colab]](https://colab.research.google.com/github/facebookresearch/dinov3/blob/main/notebooks/segmentation_tracking.ipynb)
|
| 537 |
+
|
| 538 |
+
## Data preparation
|
| 539 |
+
|
| 540 |
+
### ImageNet-1k
|
| 541 |
+
|
| 542 |
+
The root directory of the dataset should hold the following contents:
|
| 543 |
+
|
| 544 |
+
- `<ROOT>/test/ILSVRC2012_test_00000001.JPEG`
|
| 545 |
+
- `<ROOT>/test/[..]`
|
| 546 |
+
- `<ROOT>/test/ILSVRC2012_test_00100000.JPEG`
|
| 547 |
+
- `<ROOT>/train/n01440764/n01440764_10026.JPEG`
|
| 548 |
+
- `<ROOT>/train/[...]`
|
| 549 |
+
- `<ROOT>/train/n15075141/n15075141_9993.JPEG`
|
| 550 |
+
- `<ROOT>/val/n01440764/ILSVRC2012_val_00000293.JPEG`
|
| 551 |
+
- `<ROOT>/val/[...]`
|
| 552 |
+
- `<ROOT>/val/n15075141/ILSVRC2012_val_00049174.JPEG`
|
| 553 |
+
- `<ROOT>/labels.txt`
|
| 554 |
+
|
| 555 |
+
The provided dataset implementation expects a few additional metadata files to be present under the extra directory:
|
| 556 |
+
|
| 557 |
+
- `<EXTRA>/class-ids-TRAIN.npy`
|
| 558 |
+
- `<EXTRA>/class-ids-VAL.npy`
|
| 559 |
+
- `<EXTRA>/class-names-TRAIN.npy`
|
| 560 |
+
- `<EXTRA>/class-names-VAL.npy`
|
| 561 |
+
- `<EXTRA>/entries-TEST.npy`
|
| 562 |
+
- `<EXTRA>/entries-TRAIN.npy`
|
| 563 |
+
- `<EXTRA>/entries-VAL.npy`
|
| 564 |
+
|
| 565 |
+
These metadata files can be generated (once) with the following lines of Python code:
|
| 566 |
+
|
| 567 |
+
```python
|
| 568 |
+
from dinov3.data.datasets import ImageNet
|
| 569 |
+
|
| 570 |
+
for split in ImageNet.Split:
|
| 571 |
+
dataset = ImageNet(split=split, root="<ROOT>", extra="<EXTRA>")
|
| 572 |
+
dataset.dump_extra()
|
| 573 |
+
```
|
| 574 |
+
|
| 575 |
+
Note that the root and extra directories do not have to be distinct directories.
|
| 576 |
+
|
| 577 |
+
### ImageNet-22k
|
| 578 |
+
|
| 579 |
+
Please adapt the [dataset class](dinov3/data/datasets/image_net_22k.py) to match your local setup.
|
| 580 |
+
|
| 581 |
+
<br />
|
| 582 |
+
|
| 583 |
+
:warning: To execute the commands provided in the next sections for training and evaluation, the `dinov3` package should be included in the Python module search path, i.e. simply prefix the command to run with `PYTHONPATH=.`.
|
| 584 |
+
|
| 585 |
+
## Training
|
| 586 |
+
|
| 587 |
+
### Fast setup: training DINOv3 ViT-L/16 on ImageNet-1k
|
| 588 |
+
|
| 589 |
+
Run DINOv3 pre-training on 4 H100-80GB nodes (32 GPUs) in a SLURM cluster environment with submitit:
|
| 590 |
+
|
| 591 |
+
```shell
|
| 592 |
+
PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
|
| 593 |
+
--nodes 4 \
|
| 594 |
+
--config-file dinov3/configs/train/vitl_im1k_lin834.yaml \
|
| 595 |
+
--output-dir <PATH/TO/OUTPUT/DIR> \
|
| 596 |
+
train.dataset_path=ImageNet22k:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
|
| 597 |
+
```
|
| 598 |
+
Training time is approximately 14 hours and the resulting checkpoint should reach 82.0% on k-NN eval and 83.5% on linear eval.
|
| 599 |
+
|
| 600 |
+
The training code saves the weights of the teacher in the eval folder every 12500 iterations for evaluation.
|
| 601 |
+
|
| 602 |
+
### Exact DINOv3 setup: training DINOv3 ViT-7B/16
|
| 603 |
+
|
| 604 |
+
DINOv3 ViT-7B/16 is trained on a private dataset. The training involves 3 stages:
|
| 605 |
+
- Pretraining
|
| 606 |
+
- Gram anchoring
|
| 607 |
+
- High resolution adaptation
|
| 608 |
+
|
| 609 |
+
#### Pretraining
|
| 610 |
+
|
| 611 |
+
Launch DINOV3 ViT-7B/16 pretraining on 32 nodes (256 GPUs) in a SLURM cluster environment with submitit.
|
| 612 |
+
|
| 613 |
+
```shell
|
| 614 |
+
PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
|
| 615 |
+
--nodes 32 \
|
| 616 |
+
--config-file dinov3/configs/train/dinov3_vit7b16_pretrain.yaml \
|
| 617 |
+
--output-dir <PATH/TO/OUTPUT/DIR> \
|
| 618 |
+
train.dataset_path=<DATASET>:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
|
| 619 |
+
```
|
| 620 |
+
|
| 621 |
+
#### Gram anchoring
|
| 622 |
+
|
| 623 |
+
```shell
|
| 624 |
+
PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
|
| 625 |
+
--nodes 32 \
|
| 626 |
+
--config-file dinov3/configs/train/dinov3_vit7b16_gram_anchor.yaml \
|
| 627 |
+
--output-dir <PATH/TO/OUTPUT/DIR> \
|
| 628 |
+
train.dataset_path=<DATASET>:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
|
| 629 |
+
gram.ckpt=<PATH/TO/GRAM_TEACHER_FROM_PREVIOUS_STEP>
|
| 630 |
+
```
|
| 631 |
+
|
| 632 |
+
#### High-resolution adaptation
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
```shell
|
| 636 |
+
PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
|
| 637 |
+
--nodes 32 \
|
| 638 |
+
--config-file dinov3/configs/train/dinov3_vit7b16_high_res_adapt.yaml \
|
| 639 |
+
--output-dir <PATH/TO/OUTPUT/DIR> \
|
| 640 |
+
train.dataset_path=<DATASET>:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
|
| 641 |
+
gram.ckpt=<PATH/TO/TEACHER_FROM_GRAM> \
|
| 642 |
+
student.resume_from_teacher_chkpt=<PATH/TO/TEACHER_FROM_GRAM>
|
| 643 |
+
```
|
| 644 |
+
|
| 645 |
+
## Multi-distillation
|
| 646 |
+
|
| 647 |
+
### Test setup:
|
| 648 |
+
|
| 649 |
+
```shell
|
| 650 |
+
PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
|
| 651 |
+
--nodes 1 \
|
| 652 |
+
--config-file dinov3/configs/train/multi_distillation_test.yaml \
|
| 653 |
+
--output-dir <PATH/TO/OUTPUT/DIR> \
|
| 654 |
+
--multi-distillation \
|
| 655 |
+
train.dataset_path=<DATASET>:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
|
| 656 |
+
```
|
| 657 |
+
|
| 658 |
+
## Evaluation
|
| 659 |
+
|
| 660 |
+
The training code regularly saves the teacher weights. In order to evaluate the model, run the following evaluation on a single node:
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
### Logistic regression classification on ImageNet-1k
|
| 664 |
+
|
| 665 |
+
```shell
|
| 666 |
+
PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/eval/log_regression.py \
|
| 667 |
+
model.config_file=<PATH/TO/OUTPUT/DIR>/config.yaml \
|
| 668 |
+
model.pretrained_weights=<PATH/TO/OUTPUT/DIR>/teacher_checkpoint.pth \
|
| 669 |
+
output_dir=<PATH/TO/OUTPUT/DIR> \
|
| 670 |
+
train.dataset=ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
|
| 671 |
+
eval.test_dataset=ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
|
| 672 |
+
```
|
| 673 |
+
|
| 674 |
+
### k-NN classification on ImageNet-1k
|
| 675 |
+
|
| 676 |
+
```shell
|
| 677 |
+
PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/eval/knn.py \
|
| 678 |
+
model.config_file=<PATH/TO/OUTPUT/DIR>/config.yaml \
|
| 679 |
+
model.pretrained_weights=<PATH/TO/OUTPUT/DIR>/teacher_checkpoint.pth \
|
| 680 |
+
output_dir=<PATH/TO/OUTPUT/DIR> \
|
| 681 |
+
train.dataset=ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
|
| 682 |
+
eval.test_dataset=ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
|
| 683 |
+
```
|
| 684 |
+
|
| 685 |
+
### Linear classification with data augmentation on ImageNet-1k
|
| 686 |
+
|
| 687 |
+
```shell
|
| 688 |
+
PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/eval/linear.py \
|
| 689 |
+
model.config_file=<PATH/TO/OUTPUT/DIR>/config.yaml \
|
| 690 |
+
model.pretrained_weights=<PATH/TO/OUTPUT/DIR>/teacher_checkpoint.pth \
|
| 691 |
+
output_dir=<PATH/TO/OUTPUT/DIR> \
|
| 692 |
+
train.dataset=ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
|
| 693 |
+
train.val_dataset=ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
|
| 694 |
+
```
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
### Text alignment on DINOv3 using dino.txt
|
| 698 |
+
|
| 699 |
+
Text alignment can be done following the method from `dino.txt` aka [DINOv2 Meets Text](https://arxiv.org/abs/2412.16334).
|
| 700 |
+
|
| 701 |
+
```shell
|
| 702 |
+
PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/eval/text/train_dinotxt.py \
|
| 703 |
+
--nodes 4 \
|
| 704 |
+
# An example config for text alignment is here: dinov3/eval/text/configs/dinov3_vitl_text.yaml \
|
| 705 |
+
trainer_config_file="<PATH/TO/DINOv3/TEXT/CONFIG>" \
|
| 706 |
+
output-dir=<PATH/TO/OUTPUT/DIR>
|
| 707 |
+
```
|
| 708 |
+
Launching the above trains text alignment on 4 nodes with 8 gpus each (32 gpus in total).
|
| 709 |
+
Please note that the text alignment model in the DINOv3 paper was trained on a private dataset and here we have given an example config in ```dinov3/eval/text/configs/dinov3_vitl_text.yaml``` using ```CocoCaptions``` dataset for illustration purposes.
|
| 710 |
+
Please adapt the provided ```CocoCaptions``` dataset class, the dataset can be found [here](https://www.kaggle.com/datasets/nikhil7280/coco-image-caption)
|
| 711 |
+
|
| 712 |
+
## License
|
| 713 |
+
|
| 714 |
+
DINOv3 code and model weights are released under the DINOv3 License. See [LICENSE.md](LICENSE.md) for additional details.
|
| 715 |
+
|
| 716 |
+
## Contributing
|
| 717 |
+
|
| 718 |
+
See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
|
| 719 |
+
|
| 720 |
+
## Citing DINOv3
|
| 721 |
+
|
| 722 |
+
If you find this repository useful, please consider giving a star :star: and citation :t-rex::
|
| 723 |
+
|
| 724 |
+
```
|
| 725 |
+
@misc{simeoni2025dinov3,
|
| 726 |
+
title={{DINOv3}},
|
| 727 |
+
author={Sim{\'e}oni, Oriane and Vo, Huy V. and Seitzer, Maximilian and Baldassarre, Federico and Oquab, Maxime and Jose, Cijo and Khalidov, Vasil and Szafraniec, Marc and Yi, Seungeun and Ramamonjisoa, Micha{\"e}l and Massa, Francisco and Haziza, Daniel and Wehrstedt, Luca and Wang, Jianyuan and Darcet, Timoth{\'e}e and Moutakanni, Th{\'e}o and Sentana, Leonel and Roberts, Claire and Vedaldi, Andrea and Tolan, Jamie and Brandt, John and Couprie, Camille and Mairal, Julien and J{\'e}gou, Herv{\'e} and Labatut, Patrick and Bojanowski, Piotr},
|
| 728 |
+
year={2025},
|
| 729 |
+
eprint={2508.10104},
|
| 730 |
+
archivePrefix={arXiv},
|
| 731 |
+
primaryClass={cs.CV},
|
| 732 |
+
url={https://arxiv.org/abs/2508.10104},
|
| 733 |
+
}
|
| 734 |
+
```
|
InfiniDepth/model/block/torchhub/dinov3/conda.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: dinov3
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
- conda-forge
|
| 5 |
+
dependencies:
|
| 6 |
+
- python=3.11
|
| 7 |
+
- omegaconf
|
| 8 |
+
- pip
|
| 9 |
+
- pip:
|
| 10 |
+
- ftfy # needed for dino.txt
|
| 11 |
+
- iopath
|
| 12 |
+
- omegaconf
|
| 13 |
+
- pandas
|
| 14 |
+
- regex # needed for dino.txt
|
| 15 |
+
- pandas
|
| 16 |
+
- scikit-learn
|
| 17 |
+
- scikit-learn-intelex
|
| 18 |
+
- submitit
|
| 19 |
+
- termcolor
|
| 20 |
+
- torch
|
| 21 |
+
- torchvision
|
| 22 |
+
- torchmetrics
|
| 23 |
+
|
InfiniDepth/model/block/torchhub/dinov3/dinov3/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
__version__ = "0.0.1"
|
InfiniDepth/model/block/torchhub/dinov3/dinov3/checkpointer/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
from .checkpointer import (
|
| 7 |
+
CheckpointRetentionPolicy,
|
| 8 |
+
cleanup_checkpoint,
|
| 9 |
+
find_all_checkpoints,
|
| 10 |
+
find_latest_checkpoint,
|
| 11 |
+
init_fsdp_model_from_checkpoint,
|
| 12 |
+
init_model_from_checkpoint_for_evals,
|
| 13 |
+
keep_checkpoint_copy,
|
| 14 |
+
keep_last_n_checkpoints,
|
| 15 |
+
load_checkpoint,
|
| 16 |
+
register_dont_save_hooks,
|
| 17 |
+
save_checkpoint,
|
| 18 |
+
)
|
InfiniDepth/model/block/torchhub/dinov3/dinov3/checkpointer/checkpointer.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
Suggested file structure:
|
| 8 |
+
|
| 9 |
+
output_dir/
|
| 10 |
+
|-- ckpt/
|
| 11 |
+
| |-- 0/
|
| 12 |
+
| |-- 99/
|
| 13 |
+
| |-- 199/
|
| 14 |
+
| |-- 199_keep/
|
| 15 |
+
| |-- 299/
|
| 16 |
+
| `-- ...
|
| 17 |
+
`-- eval/
|
| 18 |
+
`-- 0/
|
| 19 |
+
`-- 99/
|
| 20 |
+
`-- ckpt/
|
| 21 |
+
|
| 22 |
+
Distributed checkpointer docs:
|
| 23 |
+
- https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
|
| 24 |
+
- https://pytorch.org/docs/stable/distributed.checkpoint.html
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import logging
|
| 28 |
+
import shutil
|
| 29 |
+
import subprocess
|
| 30 |
+
import tempfile
|
| 31 |
+
from enum import Enum
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
from typing import List, Sequence, Set
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import torch.distributed as dist
|
| 37 |
+
import torch.distributed.checkpoint as dcp
|
| 38 |
+
import torch.distributed.checkpoint.filesystem as dcpfs
|
| 39 |
+
import torch.distributed.checkpoint.state_dict as dcpsd
|
| 40 |
+
from torch.distributed.checkpoint.stateful import Stateful
|
| 41 |
+
|
| 42 |
+
logger = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class CheckpointRetentionPolicy(Enum):
|
| 46 |
+
ALL = "all" # keep all checkpoints
|
| 47 |
+
BEST = "best"
|
| 48 |
+
LAST = "last"
|
| 49 |
+
LAST_AND_BEST = "last_and_best"
|
| 50 |
+
NONE = "none" # do not keep any checkpoints
|
| 51 |
+
|
| 52 |
+
@property
|
| 53 |
+
def keep_filters(self) -> Set[str]:
|
| 54 |
+
"""Files that match these patterns are not deleted by cleanup"""
|
| 55 |
+
if self == CheckpointRetentionPolicy.LAST:
|
| 56 |
+
return set(["final"])
|
| 57 |
+
if self == CheckpointRetentionPolicy.BEST:
|
| 58 |
+
return set(["best"])
|
| 59 |
+
if self == CheckpointRetentionPolicy.LAST_AND_BEST:
|
| 60 |
+
return set(["final", "best"])
|
| 61 |
+
if self == CheckpointRetentionPolicy.ALL:
|
| 62 |
+
return set()
|
| 63 |
+
return set()
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def max_to_keep(self) -> int | None:
|
| 67 |
+
"""
|
| 68 |
+
maximum "periodic" checkpoints to keep concurrently, ie. saved with `step` and not `save`. `None` for keep all
|
| 69 |
+
"""
|
| 70 |
+
if self == CheckpointRetentionPolicy.ALL:
|
| 71 |
+
return None
|
| 72 |
+
return 1
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def save_checkpoint(
    ckpt_dir: str | Path,  # output_dir/ckpt/199
    *,
    iteration: int | str,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer | None = None,
    overwrite: bool = True,
    process_group: dist.ProcessGroup = None,
    **others: Stateful,
):
    """Save a plain/DDP/FSDP/FSDP2 model, its optimizer, an integer iteration and other stateful objects.

    The checkpoint is first written to a temporary directory and atomically renamed,
    so a partially written checkpoint never appears under the final name.

    Raises:
        RuntimeError: if the checkpoint already exists and `overwrite` is False.
    """
    rank = torch.distributed.get_rank(group=process_group)

    # Rank 0 checks if the checkpoint directory exists, but all ranks need to know if it exists,
    # so they can raise an error when overwrite is False. If overwrite is True, rank 0 will delete it
    # and other ranks wait for the deletion to finish.
    ckpt_dir = Path(ckpt_dir)
    ckpt_dir_exists = [ckpt_dir.exists() if rank == 0 else None]
    src_rank = 0
    if process_group is not None:
        # broadcast_object_list takes a *global* rank even for subgroup broadcasts.
        src_rank = torch.distributed.get_global_rank(group=process_group, group_rank=0)
    torch.distributed.broadcast_object_list(ckpt_dir_exists, src=src_rank, group=process_group)
    ckpt_dir_exists = ckpt_dir_exists[0]
    if ckpt_dir_exists:
        if overwrite:
            if rank == 0:
                if ckpt_dir.is_dir():
                    shutil.rmtree(ckpt_dir)
                else:
                    ckpt_dir.unlink()
                logger.info(f"Deleted: {ckpt_dir}")
            torch.distributed.barrier(group=process_group)
        else:
            raise RuntimeError(f"Checkpoint already exists: {ckpt_dir}")

    # Rank 0 creates a temporary directory for the checkpoint and broadcasts the name to all ranks.
    ckpt_dir.parent.mkdir(parents=True, exist_ok=True)
    ckpt_dir_tmp = [tempfile.mkdtemp(dir=ckpt_dir.parent, prefix=ckpt_dir.name) if rank == 0 else None]
    torch.distributed.broadcast_object_list(ckpt_dir_tmp, src=src_rank, group=process_group)
    ckpt_dir_tmp = Path(ckpt_dir_tmp[0])

    to_save = {"iteration": iteration}
    to_save["model"] = dcpsd.get_model_state_dict(model)
    if optimizer is not None:
        to_save["optimizer"] = dcpsd.get_optimizer_state_dict(model, optimizer)
    to_save.update(others)
    dcp.save(
        to_save,
        storage_writer=dcpfs.FileSystemWriter(ckpt_dir_tmp),
        process_group=process_group,
    )

    # Rank 0 renames the temporary directory to the final checkpoint directory. All ranks wait for the rename.
    if rank == 0:
        ckpt_dir_tmp.rename(ckpt_dir)
    # Fix: the original called barrier() on the default (global) group here while every
    # other collective in this function uses `process_group`; with a subgroup that
    # mismatch can deadlock or synchronize the wrong set of ranks.
    torch.distributed.barrier(group=process_group)

    logger.info(f"Saved: {ckpt_dir}")
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def load_checkpoint(
    ckpt_dir: str | Path,  # output_dir/ckpt/199
    *,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer | None = None,
    strict_loading: bool = True,
    process_group: dist.ProcessGroup = None,
    **others: Stateful,
) -> int | None:
    """
    Load a plain/DDP/FSDP/FSDP2 model, its optimizer, an integer iteration and other stateful objects.
    Can you take a checkpoint saved on N ranks and load it on M ranks? Sure you can!
    Activation checkpointing and torch-compile can also be different between save and load, no problem.

    Returns the iteration stored in the checkpoint (None if absent).
    """
    ckpt_dir = Path(ckpt_dir)
    # Build the skeleton state dict; dcp.load fills it in place, resharding as needed.
    to_load = {"iteration": None}
    to_load["model"] = dcpsd.get_model_state_dict(model)
    if optimizer is not None:
        to_load["optimizer"] = dcpsd.get_optimizer_state_dict(model, optimizer)
    to_load.update(others)
    dcp.load(
        to_load,
        storage_reader=dcpfs.FileSystemReader(ckpt_dir),
        # allow_partial_load tolerates keys present in the model but missing on disk.
        planner=dcp.default_planner.DefaultLoadPlanner(allow_partial_load=not strict_loading),
        process_group=process_group,
    )
    iteration = to_load["iteration"]
    # Push the loaded tensors back into the live model/optimizer objects.
    dcpsd.set_model_state_dict(model, to_load["model"])
    if optimizer is not None:
        dcpsd.set_optimizer_state_dict(model, optimizer, to_load["optimizer"])
    logger.info(f"Loaded: {ckpt_dir}")
    return iteration
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def register_dont_save_hooks(module: torch.nn.Module, dont_save: Sequence[str]):
    """
    Registers save/load state dict hooks such that the weights in `dont_save` are not persisted in the checkpoint.

    Typical use case: a classification model composed of a frozen backbone and a trainable head.
    If the frozen backbone is loaded from torch hub, it doesn't make sense to save a copy of it in each checkpoint.
    """

    def state_dict_post_hook(module, state_dict, prefix, local_metadata):
        # Remove frozen weights so they won't get saved.
        # If this module is not the top-level module, its weights will have a prefix in the state dict.
        nonlocal _dont_save
        for k in _dont_save:
            del state_dict[prefix + k]

    def load_state_dict_pre_hook(
        module,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # This pre hook exists only to pass the prefix to the post hook when loading the state dict.
        nonlocal _prefix
        # A non-None prefix here means pre/post hook calls got interleaved unexpectedly.
        assert _prefix is None
        _prefix = prefix

    def load_state_dict_post_hook(module, incompatible_keys):
        # Remove the frozen weights from the missing keys so they don't raise an error.
        nonlocal _prefix
        assert _prefix is not None
        to_remove = []
        for missing_key in incompatible_keys.missing_keys:
            k = missing_key.removeprefix(_prefix)
            k = k.replace("_checkpoint_wrapped_module.", "")  # Added by activation checkpointing
            if k in _dont_save:
                to_remove.append(missing_key)
        for r in to_remove:
            incompatible_keys.missing_keys.remove(r)
        _prefix = None

    # Normalize names so they match whether or not activation checkpointing wrapped the submodules.
    _dont_save = set(name.replace("_checkpoint_wrapped_module.", "") for name in dont_save)
    _prefix = None
    module.register_state_dict_post_hook(state_dict_post_hook)
    module.register_load_state_dict_pre_hook(load_state_dict_pre_hook)
    module.register_load_state_dict_post_hook(load_state_dict_post_hook)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def find_all_checkpoints(ckpt_dir: Path | str) -> list[Path]:
|
| 221 |
+
"""Find all checkpoints in a directory, i.e. subdirs with integer name. Sorted from first to last."""
|
| 222 |
+
ckpt_dir = Path(ckpt_dir)
|
| 223 |
+
if not ckpt_dir.is_dir():
|
| 224 |
+
return []
|
| 225 |
+
checkpoints = [p for p in ckpt_dir.iterdir() if p.is_dir() and _is_int(p.name)]
|
| 226 |
+
checkpoints.sort(key=lambda p: int(p.name))
|
| 227 |
+
return checkpoints
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def find_latest_checkpoint(ckpt_dir: Path | str) -> Path | None:
    """Find the latest checkpoint in a directory, i.e. the subdir with the highest integer name."""
    all_ckpts = find_all_checkpoints(ckpt_dir)
    # find_all_checkpoints returns them sorted ascending, so the last one is the latest.
    return all_ckpts[-1] if all_ckpts else None
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def keep_last_n_checkpoints(ckpt_dir: Path | str, n: int | None):
    """In a directory with integer-named subdirs, keep only the n subdirs with the highest number.

    Args:
        ckpt_dir: directory containing integer-named checkpoint subdirs.
        n: how many of the most recent checkpoints to keep. ``None`` keeps everything;
           ``0`` deletes all of them.
    """
    if n is None:
        return
    checkpoints = find_all_checkpoints(ckpt_dir)
    # Fix: `checkpoints[:-0]` is an empty slice, so the original silently kept
    # everything when asked to keep zero checkpoints. Handle n == 0 explicitly.
    stale = checkpoints if n == 0 else checkpoints[:-n]
    # Use a distinct loop variable instead of shadowing the `ckpt_dir` parameter.
    for stale_ckpt in stale:
        try:
            shutil.rmtree(stale_ckpt)
            logger.info(f"Deleted: {stale_ckpt}")
        except Exception:
            # Best-effort cleanup: log and continue with the remaining checkpoints.
            logger.exception(f"Failed to delete: {stale_ckpt}")
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def keep_checkpoint_copy(src: Path | str):
    """Copy a file/directory next to itself with a _keep suffix. Files are hardlinked."""
    source = Path(src)
    destination = source.parent / (source.name + "_keep")
    # `cp --link` hardlinks regular files instead of duplicating their contents,
    # so the copy costs almost no extra disk space.
    subprocess.check_output(["cp", "--recursive", "--link", source, destination])
    logger.info(f"Copied: {source} -> {destination}")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def _is_int(s: str) -> bool:
|
| 260 |
+
try:
|
| 261 |
+
int(s)
|
| 262 |
+
return True
|
| 263 |
+
except ValueError:
|
| 264 |
+
return False
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# Initialize a FSDP2 model from DCP or PyTorch standard checkpoint
|
| 268 |
+
def init_fsdp_model_from_checkpoint(
    model: torch.nn.Module,
    checkpoint_path: str,
    skip_load_prefixes: List[str] | None = None,
    prefixes_not_sharded: List[str] | None = None,
    process_group: dist.ProcessGroup = None,
):
    """Initialize an FSDP2 model from either a DCP directory or a standard PyTorch checkpoint.

    For a standard checkpoint, the `"teacher"` entry is loaded on CPU, each tensor is
    distributed over the device mesh unless its key matches `prefixes_not_sharded`,
    and keys matching `skip_load_prefixes` are dropped before `load_state_dict`.
    """
    # Fix: the declared defaults are None, but the original iterated both lists
    # unconditionally and crashed with a TypeError when they were left unset.
    skip_load_prefixes = skip_load_prefixes or []
    prefixes_not_sharded = prefixes_not_sharded or []
    if not Path(checkpoint_path).is_dir():  # PyTorch standard checkpoint
        logger.info(f"Loading pretrained weights from {checkpoint_path}")
        chkpt = torch.load(checkpoint_path, map_location="cpu")["teacher"]
        from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

        if process_group is None:
            world_mesh = init_device_mesh(
                "cuda",
                mesh_shape=(dist.get_world_size(),),
                mesh_dim_names=("dp",),
            )
        else:
            world_mesh = DeviceMesh.from_group(process_group, "cuda")
        # Fix: the original double comprehension iterated prefixes in the outer loop,
        # so with several prefixes the last one silently overwrote earlier decisions.
        # A key is kept unsharded when it matches ANY of the prefixes.
        chkpt = {
            k: (
                v
                if any(k.startswith(pns) for pns in prefixes_not_sharded)
                else torch.distributed.tensor.distribute_tensor(v, world_mesh, src_data_rank=None)
            )
            for k, v in chkpt.items()
        }
        model.load_state_dict(
            {k: v for k, v in chkpt.items() if not any(k.startswith(prefix) for prefix in skip_load_prefixes)}
        )
    else:  # DCP checkpoint
        load_checkpoint(ckpt_dir=checkpoint_path, model=model, process_group=process_group)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# Initialize a standard non distributed PyTorch model from PyTorch standard checkpoint for evals
|
| 305 |
+
def init_model_from_checkpoint_for_evals(
    model: torch.nn.Module, pretrained_weights: str | Path, checkpoint_key: str = None
):
    """Load eval weights into a plain (non-distributed) model, stripping wrapper prefixes.

    Keys containing `module.` (DDP wrapper) or `backbone.` (multicrop wrapper) are
    renamed before a non-strict `load_state_dict`.
    """
    state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]
    cleaned = {}
    for key, value in state_dict.items():
        # remove `module.` prefix, then the `backbone.` prefix induced by multicrop wrapper
        key = key.replace("module.", "")
        key = key.replace("backbone.", "")
        cleaned[key] = value
    msg = model.load_state_dict(cleaned, strict=False)
    logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def cleanup_checkpoint(ckpt_dir: str, checkpoint_retention_policy: CheckpointRetentionPolicy):
    """
    Delete the checkpoint subdirectories that the retention policy does not protect.

    ckpt_dir is the directory containing each individual checkpoint directories (either at iteration, best (validation performance) or final)
    |-- ckpt_dir/
    |   |-- 0/
    |       |--checkpoint.pth or dcp_sharded_checkpoint_dir
    |   |-- 99/
            |--checkpoint.pth or dcp_sharded_checkpoint_dir
    |   |-- 199/
            |--checkpoint.pth or dcp_sharded_checkpoint_dir
    |   |-- best/
            |--checkpoint.pth or dcp_sharded_checkpoint_dir
    |   |-- 299/
            |--checkpoint.pth or dcp_sharded_checkpoint_dir
    |   |-- final/
            |--checkpoint.pth or dcp_sharded_checkpoint_dir
    """
    ckpt_dir = Path(ckpt_dir)
    if not ckpt_dir.is_dir():
        return []
    checkpoint_filters = checkpoint_retention_policy.keep_filters
    checkpoints = [p for p in ckpt_dir.iterdir() if p.is_dir()]
    for checkpoint in checkpoints:
        # Fix: compare the directory *name* to the filters. The original compared the
        # Path object itself to a set of strings, which never matched, so protected
        # checkpoints ("best"/"final") were deleted regardless of the policy.
        if checkpoint.name in checkpoint_filters:
            continue
        try:
            shutil.rmtree(checkpoint)
            logger.info(f"Deleted: {checkpoint}")
        except Exception:
            # Best-effort cleanup: log the failure and keep going.
            logger.exception(f"Failed to delete: {checkpoint}")
|
InfiniDepth/model/block/torchhub/dinov3/dinov3/configs/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
from .config import (
|
| 7 |
+
DinoV3SetupArgs,
|
| 8 |
+
apply_scaling_rules_to_cfg,
|
| 9 |
+
exit_job,
|
| 10 |
+
get_cfg_from_args,
|
| 11 |
+
get_default_config,
|
| 12 |
+
setup_config,
|
| 13 |
+
setup_job,
|
| 14 |
+
setup_multidistillation,
|
| 15 |
+
write_config,
|
| 16 |
+
)
|
InfiniDepth/model/block/torchhub/dinov3/dinov3/configs/config.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
import pathlib
|
| 10 |
+
import sys
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from datetime import timedelta
|
| 13 |
+
from typing import Any, List, Optional, Sequence, Tuple
|
| 14 |
+
|
| 15 |
+
from omegaconf import DictConfig, OmegaConf
|
| 16 |
+
|
| 17 |
+
import dinov3.distributed as distributed
|
| 18 |
+
from dinov3.logging import cleanup_logging, setup_logging
|
| 19 |
+
from dinov3.utils import fix_random_seeds, get_conda_env, get_sha
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger("dinov3")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class DinoV3SetupArgs:
    """Setup arguments for a DINOv3 job (config file, weights, output dir, CLI overrides)."""

    config_file: str  # path to the experiment's YAML config
    pretrained_weights: str | None = None  # optional checkpoint to initialize from
    shard_unsharded_model: bool = False  # whether to shard a model loaded from an unsharded checkpoint
    output_dir: str = ""  # where configs/logs/checkpoints are written
    opts: List[Any] = field(default_factory=lambda: [])  # extra "key=value" config overrides

    def __post_init__(self):
        # When loaded from benchmark.yaml, self.opts is a frozen omegaconf.ListConfig,
        # which works everywhere except when we want to modify it or when
        # we try to json-serialize it. So we convert it to a regular list here.
        if OmegaConf.is_config(self.opts):
            self.opts = OmegaConf.to_object(self.opts)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def apply_scaling_rules_to_cfg(cfg):  # to fix
    """Scale cfg.optim.lr in place according to the global batch size and cfg.optim.scaling_rule.

    Requires distributed to be initialized (the world size enters the scaling formula).
    Returns the (possibly modified) cfg.
    """
    assert distributed.is_enabled(), "Setup distributed to get global size !"
    if "schedules" in cfg:
        # For schedules v2, the scaling rules are applied when building the schedules, the config is not modified
        return cfg

    if cfg.optim.scaling_rule == "linear_wrt_256":
        # lr scales linearly with global batch size, referenced to batch size 256.
        old_lr = cfg.optim.lr
        cfg.optim.lr *= cfg.train.batch_size_per_gpu * distributed.get_world_size() / 256.0
        logger.info(f"linear scaling learning rate; old: {old_lr}, new: {cfg.optim.lr}")
    elif cfg.optim.scaling_rule == "sqrt_wrt_1024":
        # lr scales with the square root of the global batch size, referenced to batch size 1024.
        old_lr = cfg.optim.lr
        cfg.optim.lr *= 4 * math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_world_size() / 1024.0)
        logger.info(f"sqrt scaling learning rate; old: {old_lr}, new: {cfg.optim.lr}")
    return cfg
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def write_config(cfg, output_dir, name="config.yaml"):
    """Log the config and persist it as YAML under output_dir; returns the saved path."""
    logger.info(OmegaConf.to_yaml(cfg))
    target_path = os.path.join(os.path.abspath(output_dir), name)
    with open(target_path, "w") as f:
        OmegaConf.save(config=cfg, f=f)
    return target_path
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_default_config() -> DictConfig:
    """Load the bundled ssl_default_config.yaml that ships next to this module."""
    default_path = pathlib.Path(__file__).parent / "ssl_default_config.yaml"
    return OmegaConf.load(default_path)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_cfg_from_args(args: DinoV3SetupArgs, multidistillation=False, strict=True):
    """Build the experiment config: YAML file merged with CLI overrides over the defaults.

    For multidistillation the default config is not merged here; it is merged later,
    per student. With strict=True, overrides may only touch keys that exist in the
    default config (OmegaConf struct mode).
    """
    overrides = [*args.opts]
    if args.output_dir is not None:
        overrides.append(f"train.output_dir={os.path.realpath(args.output_dir)}")

    # Config file
    cfg = OmegaConf.load(args.config_file)

    # Command line overrides
    opts_cfg = OmegaConf.from_cli(overrides)

    if multidistillation:
        cfg = OmegaConf.merge(cfg, opts_cfg)
    else:
        # Default config
        default_cfg = get_default_config()
        if strict:
            # Struct mode makes merge fail on keys that don't exist in the defaults.
            OmegaConf.set_struct(default_cfg, True)
        # Precedence (lowest to highest): defaults < config file < CLI overrides.
        cfg = OmegaConf.merge(default_cfg, cfg, opts_cfg)
    return cfg
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def setup_config(args: DinoV3SetupArgs, strict_cfg=True):
    """
    Create configs and perform basic setups.

    Returns the merged config with the learning-rate scaling rules already applied.
    """
    # Create the cfg with OmegaConf
    cfg = get_cfg_from_args(args, strict=strict_cfg)
    # setup distributed, logging, and random seeds
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    # dump config before modifying so it can be reloaded
    if args.output_dir is not None:
        write_config(cfg, args.output_dir)
    # modify the config inplace by applying scaling rules
    apply_scaling_rules_to_cfg(cfg)
    return cfg
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _enumerate_all_subgroup_ranks(all_subgroup_rank_spans: Sequence[Tuple[int, int]]):
|
| 110 |
+
"""Expands a specification of process subgroups from spans to enumerated ranks.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
all_group_rank_spans: a sequence of rank spans (first rank, last rank),
|
| 114 |
+
one for each process group. Example: ((0, 1), (2, 3), (4, 7)).
|
| 115 |
+
"""
|
| 116 |
+
for first, last in all_subgroup_rank_spans:
|
| 117 |
+
assert first <= last
|
| 118 |
+
return tuple(tuple(range(first, last + 1)) for first, last in all_subgroup_rank_spans)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def setup_multidistillation(args: DinoV3SetupArgs):
    """Per-rank setup for multidistillation: split the world into per-student process
    subgroups, select this rank's student, and return the merged config.

    Mutates `args` (output_dir, opts, config_file) to point at the selected student.
    """
    base_output_dir = args.output_dir
    os.makedirs(args.output_dir, exist_ok=True)
    # get config file for this rank
    base_cfg = OmegaConf.load(args.config_file)
    assert base_cfg.multidistillation.enabled

    global_batch_size = base_cfg.multidistillation.global_batch_size

    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_rank()

    # build process subgroups
    # ranks_range is half-open [start, end), so the last rank of the span is end - 1.
    all_subgroup_rank_spans = tuple(
        (student.ranks_range[0], student.ranks_range[1] - 1) for student in base_cfg.multidistillation.students
    )
    all_subgroup_ranks = _enumerate_all_subgroup_ranks(all_subgroup_rank_spans)
    distributed.new_subgroups(all_subgroup_ranks)

    # Find the student whose rank range contains this rank; the loop variable
    # `student` intentionally leaks and is used below.
    found = False
    for student in base_cfg.multidistillation.students:
        if rank in range(*student.ranks_range):
            found = True
            break
    assert found, "rank of worker not in defined range"

    name = student.name
    config_path = student.config_path
    n_gpus = student.ranks_range[1] - student.ranks_range[0]
    assert global_batch_size % n_gpus == 0
    total_n_gpus = distributed.get_world_size()

    # Each student writes to its own subdirectory and splits the global batch evenly.
    args.output_dir = os.path.join(base_output_dir, name)
    args.opts += [f"train.output_dir={args.output_dir}"]
    args.opts += [f"train.batch_size_per_gpu={global_batch_size // total_n_gpus}"]
    args.config_file = os.path.abspath(config_path)
    default_cfg = get_default_config()
    cfg = OmegaConf.load(args.config_file)
    # Precedence (lowest to highest): defaults < student config < base config < CLI opts.
    cfg = OmegaConf.merge(default_cfg, cfg, base_cfg, OmegaConf.from_cli(args.opts))

    global logger
    setup_logging(output=args.output_dir, level=logging.INFO)

    fix_random_seeds(seed + rank)

    write_config(cfg, args.output_dir)
    apply_scaling_rules_to_cfg(cfg)

    return cfg
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def setup_job(
    output_dir: Optional[str] = None,
    distributed_enabled: bool = True,
    logging_enabled: bool = True,
    seed: Optional[int] = 0,
    restrict_print_to_main_process: bool = True,
    distributed_timeout: timedelta | None = None,
):
    """
    Setup methods that should be done in every fairvit job
    Initializes logging, distributed, random seeds and other utilities.
    """
    if output_dir is not None:
        output_dir = os.path.realpath(output_dir)
        os.makedirs(output_dir, exist_ok=True)

    if logging_enabled:
        setup_logging(
            output=output_dir,
            level=logging.INFO,
            log_to_stdout_only_in_main_process=restrict_print_to_main_process,
        )

    if distributed_enabled:
        distributed.enable(
            overwrite=True,
            nccl_async_error_handling=True,
            restrict_print_to_main_process=restrict_print_to_main_process,
            timeout=distributed_timeout,
        )

    if seed is not None:
        # Offset by rank so every process draws a different random stream.
        rank = distributed.get_rank()
        fix_random_seeds(seed + rank)

    logger = logging.getLogger("fairvit")
    logger.info("git:\n {}\n".format(get_sha()))

    # Log some python info
    conda_env_name, conda_env_path = get_conda_env()
    logger.info(f"conda env name: {conda_env_name}")
    logger.info(f"conda env path: {conda_env_path}")
    logger.info(f"python path: {sys.path}")
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def exit_job(distributed_enabled: bool = True, logging_enabled: bool = True):
    """Counterpart of setup_job: tear down the process group and the logging handlers."""
    if distributed_enabled:
        distributed.disable()
    if logging_enabled:
        cleanup_logging()
|
InfiniDepth/model/block/torchhub/dinov3/dinov3/configs/ssl_default_config.yaml
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL:
|
| 2 |
+
META_ARCHITECTURE: SSLMetaArch
|
| 3 |
+
DEVICE: cuda
|
| 4 |
+
WEIGHTS: ''
|
| 5 |
+
DTYPE: float32
|
| 6 |
+
compute_precision:
|
| 7 |
+
param_dtype: bf16
|
| 8 |
+
reduce_dtype: fp32
|
| 9 |
+
sharding_strategy: SHARD_GRAD_OP
|
| 10 |
+
dino:
|
| 11 |
+
loss_weight: 1.0
|
| 12 |
+
global_ignore_diagonal: true # Whether to ignore A-A and B-B global pairs, default as in DINOv2, ignored by SSLMetaArch
|
| 13 |
+
head_n_prototypes: 65536
|
| 14 |
+
head_bottleneck_dim: 256
|
| 15 |
+
head_norm_last_layer: false
|
| 16 |
+
head_nlayers: 3
|
| 17 |
+
head_hidden_dim: 2048
|
| 18 |
+
koleo_loss_weight: 0.1
|
| 19 |
+
koleo_loss_distributed: false
|
| 20 |
+
koleo_topk: 1
|
| 21 |
+
koleo_distributed_replicas: 0
|
| 22 |
+
koleo_distributed_loss_group_size: null # Size of the nearest neighbor set for distributed Koleo. If None, uses global batch size.
|
| 23 |
+
koleo_distributed_loss_group_data: true # group data from adjacent ranks to make sure koleo is applied on the same data distribution
|
| 24 |
+
force_weight_norm: false
|
| 25 |
+
reweight_dino_local_loss: false # If true, reweighting of DINO loss
|
| 26 |
+
local_loss_weight_schedule: # Schedule for local loss weight, enabled if reweight_dino_local_loss is true
|
| 27 |
+
start: 0.5
|
| 28 |
+
peak: 0.5
|
| 29 |
+
end: 0.5
|
| 30 |
+
warmup_epochs: 0
|
| 31 |
+
ibot:
|
| 32 |
+
loss_weight: 1.0
|
| 33 |
+
mask_sample_probability: 0.5
|
| 34 |
+
mask_ratio_min_max:
|
| 35 |
+
- 0.1
|
| 36 |
+
- 0.5
|
| 37 |
+
mask_random_circular_shift: false
|
| 38 |
+
force_masking_even_with_zero_weight: False
|
| 39 |
+
separate_head: true
|
| 40 |
+
head_n_prototypes: 65536
|
| 41 |
+
head_bottleneck_dim: 256
|
| 42 |
+
head_norm_last_layer: false
|
| 43 |
+
head_nlayers: 3
|
| 44 |
+
head_hidden_dim: 2048
|
| 45 |
+
gram:
|
| 46 |
+
use_loss: false # (bool) if true gram is used, else not
|
| 47 |
+
compute_stats: false # (bool): if true compute auxiliary stats
|
| 48 |
+
loss_weight: 1.0 # (float): weight of the loss
|
| 49 |
+
ema_teacher: false # (bool): using the EMA teacher as GRAM teacher
|
| 50 |
+
ckpt: null #(str): Checkpoint to the teacher
|
| 51 |
+
it_load_ema_teacher: -1 # (int): iteration at which the ema teacher is loaded into the gram teacher
|
| 52 |
+
rep_update: true # (bool): if true GRAM teacher updated every gram.update_frequency after iter gram.it_first_update steps
|
| 53 |
+
update_frequency: 50000 # (int): update frequency
|
| 54 |
+
it_first_update: 0 # (int): iteration of the first update
|
| 55 |
+
max_updates: null # (int): maximum number of updates to gram teacher. If None, it is unlimited
|
| 56 |
+
normalized: true # (bool): normalization of the features
|
| 57 |
+
img_level: false # (bool): if true GRAM computation at the image else, otherwise at the local batch level
|
| 58 |
+
remove_neg: false # (bool): if true remove the negative similarities before applying the loss
|
| 59 |
+
remove_only_teacher_neg: false # (bool): remove negative similarities of the teacher
|
| 60 |
+
tokens_used: all # (str): In [all, masked, unmasked]
|
| 61 |
+
global_teacher_resize_method: bicubic # Method for resizing the outputs of the gram teacher
|
| 62 |
+
global_teacher_resize_antialias: false # Whether to use antialiasing when resizing the outputs of the gram teacher
|
| 63 |
+
loss_weight_schedule: null # (dict): If not None, use a schedule for the loss weight instead of `loss_weight`
|
| 64 |
+
train:
|
| 65 |
+
batch_size_per_gpu: 64
|
| 66 |
+
dataset_path: ImageNet:split=TRAIN
|
| 67 |
+
data_config: null
|
| 68 |
+
output_dir: .
|
| 69 |
+
saveckp_freq: 20
|
| 70 |
+
seed: 0
|
| 71 |
+
num_workers: 10
|
| 72 |
+
OFFICIAL_EPOCH_LENGTH: 1250
|
| 73 |
+
monitor_gradient_norm: false
|
| 74 |
+
chunk_schedule: []
|
| 75 |
+
use_teacher_head: true
|
| 76 |
+
learn_from_teacher_tokens: false
|
| 77 |
+
centering: "sinkhorn_knopp" # or "centering"
|
| 78 |
+
checkpointing: false
|
| 79 |
+
checkpointing_full: false # aggressive checkpointing
|
| 80 |
+
compile: true
|
| 81 |
+
cudagraphs: false
|
| 82 |
+
sharded_eval_checkpoint: false
|
| 83 |
+
cache_dataset: false
|
| 84 |
+
student:
|
| 85 |
+
arch: vit_large
|
| 86 |
+
patch_size: 16
|
| 87 |
+
drop_path_rate: 0.3
|
| 88 |
+
layerscale: 1.0e-05
|
| 89 |
+
pretrained_weights: ''
|
| 90 |
+
ffn_layer: "mlp"
|
| 91 |
+
ffn_ratio: 4.0
|
| 92 |
+
resume_from_teacher_chkpt: ""
|
| 93 |
+
qkv_bias: true
|
| 94 |
+
proj_bias: true
|
| 95 |
+
ffn_bias: true
|
| 96 |
+
norm_layer: "layernorm"
|
| 97 |
+
n_storage_tokens: 0
|
| 98 |
+
mask_k_bias: false
|
| 99 |
+
untie_cls_and_patch_norms: false # If true, use separate norms for CLS/reg and patch/mask tokens
|
| 100 |
+
untie_global_and_local_cls_norm: false # If true, use separate norms for local and global crop CLS token during training
|
| 101 |
+
in_chans: 3
|
| 102 |
+
pos_embed_type: rope
|
| 103 |
+
pos_embed_rope_base: 100.0
|
| 104 |
+
pos_embed_rope_min_period: null
|
| 105 |
+
pos_embed_rope_max_period: null
|
| 106 |
+
pos_embed_rope_normalize_coords: separate # min, max, separate
|
| 107 |
+
pos_embed_rope_shift_coords: null
|
| 108 |
+
pos_embed_rope_jitter_coords: null
|
| 109 |
+
pos_embed_rope_rescale_coords: null
|
| 110 |
+
pos_embed_rope_dtype: bf16
|
| 111 |
+
fp8_enabled: false # Convert Linear layers to operate in fp8 precision
|
| 112 |
+
fp8_filter: "blocks" # Regex that must appear in module path; empty means everything
|
| 113 |
+
teacher:
|
| 114 |
+
momentum_teacher: 0.992
|
| 115 |
+
final_momentum_teacher: 1
|
| 116 |
+
warmup_teacher_temp: 0.04
|
| 117 |
+
teacher_temp: 0.07
|
| 118 |
+
warmup_teacher_temp_epochs: 30
|
| 119 |
+
in_chans: 3
|
| 120 |
+
distillation: # teacher
|
| 121 |
+
enabled: false
|
| 122 |
+
full_cfg_path: ""
|
| 123 |
+
checkpoint_path: ""
|
| 124 |
+
multidistillation:
|
| 125 |
+
enabled: false
|
| 126 |
+
hrft: # non-hrft'd student
|
| 127 |
+
enabled: false
|
| 128 |
+
checkpoint_path: "" # teacher_checkpoint path
|
| 129 |
+
optim:
|
| 130 |
+
epochs: 100
|
| 131 |
+
optimizer: adamw
|
| 132 |
+
weight_decay: 0.04
|
| 133 |
+
weight_decay_end: 0.4
|
| 134 |
+
lr: 0.001
|
| 135 |
+
warmup_epochs: 10
|
| 136 |
+
min_lr: 1.0e-06
|
| 137 |
+
schedule_trunc_extra: 0.0 # Compute the schedule for (1 + schedule_trunc_extra) steps and truncate, .25 is a good choice
|
| 138 |
+
clip_grad: 3.0
|
| 139 |
+
freeze_last_layer_epochs: 1
|
| 140 |
+
scaling_rule: sqrt_wrt_1024
|
| 141 |
+
patch_embed_lr_mult: 0.2
|
| 142 |
+
dino_head_wd_multiplier: 1.0
|
| 143 |
+
layerwise_decay: 0.9
|
| 144 |
+
multi_tensor_optim: true
|
| 145 |
+
dump_fsdp_weights_path: ""
|
| 146 |
+
adamw_beta1: 0.9
|
| 147 |
+
adamw_beta2: 0.999
|
| 148 |
+
crops:
|
| 149 |
+
global_crops_scale:
|
| 150 |
+
- 0.32
|
| 151 |
+
- 1.0
|
| 152 |
+
local_crops_number: 8
|
| 153 |
+
local_crops_scale:
|
| 154 |
+
- 0.05
|
| 155 |
+
- 0.32
|
| 156 |
+
global_crops_size: 224
|
| 157 |
+
local_crops_size: 96
|
| 158 |
+
global_local_crop_pairs_ratios: 1.0
|
| 159 |
+
gram_teacher_crops_size: null # If not None, return crops for gram teacher
|
| 160 |
+
localcrops_subset_of_globalcrops: false
|
| 161 |
+
share_color_jitter: false
|
| 162 |
+
horizontal_flips: true
|
| 163 |
+
gram_teacher_no_distortions: false # If True, no distortions are applied to gram teacher crops
|
| 164 |
+
rgb_mean:
|
| 165 |
+
- 0.485
|
| 166 |
+
- 0.456
|
| 167 |
+
- 0.406
|
| 168 |
+
rgb_std:
|
| 169 |
+
- 0.229
|
| 170 |
+
- 0.224
|
| 171 |
+
- 0.225
|
| 172 |
+
evaluation:
|
| 173 |
+
eval_period_iterations: 12500
|
| 174 |
+
low_freq_every: 5
|
| 175 |
+
config_files: # Must be in fairvit/eval/configs
|
| 176 |
+
high_freq: benchmark_high_frequency.yaml # More often
|
| 177 |
+
low_freq: benchmark_low_frequency.yaml # Less often
|
| 178 |
+
checkpointing:
|
| 179 |
+
period: 3750
|
| 180 |
+
max_to_keep: 3
|
| 181 |
+
keep_every: 99999999999999999 # Save a checkpoint every N iterations, regardless of max_to_keep and period
|
| 182 |
+
|
| 183 |
+
# Example of constant schedules with schedules v2
|
| 184 |
+
# # schedules:
|
| 185 |
+
# # lr:
|
| 186 |
+
# # start: 0.0
|
| 187 |
+
# # peak: 1e-3
|
| 188 |
+
# # end: 1e-6
|
| 189 |
+
# # warmup_epochs: 10
|
| 190 |
+
# # freeze_last_layer_epochs: 1
|
| 191 |
+
# # weight_decay:
|
| 192 |
+
# # start: 0.04
|
| 193 |
+
# # peak: 0.04
|
| 194 |
+
# # end: 0.04
|
| 195 |
+
# # warmup_epochs: 0
|
| 196 |
+
# # momentum:
|
| 197 |
+
# # start: 0.992
|
| 198 |
+
# # peak: 0.992
|
| 199 |
+
# # end: 0.992
|
| 200 |
+
# # warmup_epochs: 0
|
| 201 |
+
# # teacher_temp:
|
| 202 |
+
# # start: 0.04
|
| 203 |
+
# # peak: 0.07
|
| 204 |
+
# # end: 0.07
|
| 205 |
+
# # warmup_epochs: 30
|