ritianyu commited on
Commit
78bb21c
·
0 Parent(s):
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. InfiniDepth/__init__.py +1 -0
  2. InfiniDepth/model/__init__.py +9 -0
  3. InfiniDepth/model/block/__init__.py +1 -0
  4. InfiniDepth/model/block/common.py +43 -0
  5. InfiniDepth/model/block/config.py +5 -0
  6. InfiniDepth/model/block/convolution.py +229 -0
  7. InfiniDepth/model/block/implicit_decoder.py +179 -0
  8. InfiniDepth/model/block/pe.py +222 -0
  9. InfiniDepth/model/block/perceive_io.py +274 -0
  10. InfiniDepth/model/block/prompt_models/__init__.py +31 -0
  11. InfiniDepth/model/block/prompt_models/__pycache__/__init__.cpython-310.pyc +0 -0
  12. InfiniDepth/model/block/prompt_models/__pycache__/__init__.cpython-311.pyc +0 -0
  13. InfiniDepth/model/block/prompt_models/__pycache__/crossattn.cpython-310.pyc +0 -0
  14. InfiniDepth/model/block/prompt_models/__pycache__/diffattn.cpython-310.pyc +0 -0
  15. InfiniDepth/model/block/prompt_models/__pycache__/rope.cpython-310.pyc +0 -0
  16. InfiniDepth/model/block/prompt_models/__pycache__/rope.cpython-311.pyc +0 -0
  17. InfiniDepth/model/block/prompt_models/__pycache__/sam.cpython-310.pyc +0 -0
  18. InfiniDepth/model/block/prompt_models/__pycache__/sam.cpython-311.pyc +0 -0
  19. InfiniDepth/model/block/prompt_models/__pycache__/selfattn.cpython-310.pyc +0 -0
  20. InfiniDepth/model/block/prompt_models/__pycache__/selfattn.cpython-311.pyc +0 -0
  21. InfiniDepth/model/block/prompt_models/crossattn.py +164 -0
  22. InfiniDepth/model/block/prompt_models/rope.py +215 -0
  23. InfiniDepth/model/block/prompt_models/sam.py +126 -0
  24. InfiniDepth/model/block/prompt_models/selfattn.py +289 -0
  25. InfiniDepth/model/block/prompt_models/utils/__init__.py +1 -0
  26. InfiniDepth/model/block/prompt_models/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  27. InfiniDepth/model/block/prompt_models/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  28. InfiniDepth/model/block/prompt_models/utils/__pycache__/pe_utils.cpython-310.pyc +0 -0
  29. InfiniDepth/model/block/prompt_models/utils/__pycache__/pe_utils.cpython-311.pyc +0 -0
  30. InfiniDepth/model/block/prompt_models/utils/__pycache__/transformer.cpython-310.pyc +0 -0
  31. InfiniDepth/model/block/prompt_models/utils/__pycache__/transformer.cpython-311.pyc +0 -0
  32. InfiniDepth/model/block/prompt_models/utils/pe_utils.py +72 -0
  33. InfiniDepth/model/block/prompt_models/utils/transformer.py +250 -0
  34. InfiniDepth/model/block/rope.py +69 -0
  35. InfiniDepth/model/block/torchhub/README.md +3 -0
  36. InfiniDepth/model/block/torchhub/dinov3/.docstr.yaml +6 -0
  37. InfiniDepth/model/block/torchhub/dinov3/.github/workflows/lint.yaml +47 -0
  38. InfiniDepth/model/block/torchhub/dinov3/.gitignore +18 -0
  39. InfiniDepth/model/block/torchhub/dinov3/CODE_OF_CONDUCT.md +80 -0
  40. InfiniDepth/model/block/torchhub/dinov3/CONTRIBUTING.md +31 -0
  41. InfiniDepth/model/block/torchhub/dinov3/LICENSE.md +66 -0
  42. InfiniDepth/model/block/torchhub/dinov3/MODEL_CARD.md +432 -0
  43. InfiniDepth/model/block/torchhub/dinov3/README.md +734 -0
  44. InfiniDepth/model/block/torchhub/dinov3/conda.yaml +23 -0
  45. InfiniDepth/model/block/torchhub/dinov3/dinov3/__init__.py +6 -0
  46. InfiniDepth/model/block/torchhub/dinov3/dinov3/checkpointer/__init__.py +18 -0
  47. InfiniDepth/model/block/torchhub/dinov3/dinov3/checkpointer/checkpointer.py +349 -0
  48. InfiniDepth/model/block/torchhub/dinov3/dinov3/configs/__init__.py +16 -0
  49. InfiniDepth/model/block/torchhub/dinov3/dinov3/configs/config.py +222 -0
  50. InfiniDepth/model/block/torchhub/dinov3/dinov3/configs/ssl_default_config.yaml +205 -0
InfiniDepth/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """InfiniDepth package."""
InfiniDepth/model/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .registry import MODEL_REGISTRY, register_model
2
+ from .model import InfiniDepth, InfiniDepth_DC
3
+
4
+ __all__ = [
5
+ "MODEL_REGISTRY",
6
+ "register_model",
7
+ "InfiniDepth",
8
+ "InfiniDepth_DC",
9
+ ]
InfiniDepth/model/block/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Core building blocks for InfiniDepth models."""
InfiniDepth/model/block/common.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from typing import Type
11
+
12
+
13
class MLPBlock(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear.

    Args:
        embedding_dim: Input and output feature dimension.
        mlp_dim: Hidden (expansion) dimension.
        act: Activation module class placed between the two linear layers.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand, activate, then project back to `embedding_dim`."""
        hidden = self.act(self.lin1(x))
        return self.lin2(hidden)
27
+
28
+
29
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension of an NCHW tensor.

    Each spatial position is normalized across its channels, then a
    learnable per-channel affine transform is applied.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=1, keepdim=True)
        # Biased variance across channels.
        var = (x - mean).pow(2).mean(dim=1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        # Broadcast the per-channel affine parameters over H and W.
        return normed * self.weight[:, None, None] + self.bias[:, None, None]
InfiniDepth/model/block/config.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
# Per-backbone settings for DINOv3 encoders.
# NOTE(review): field meanings are inferred from their names — confirm
# against the decoder code that consumes this dict:
#   encoder:      backbone identifier string
#   features:     presumably the decoder feature width
#   out_channels: channels of the four intermediate feature maps
#   layer_idxs:   transformer block indices the features are taken from
dinov3_model_configs = {
    "vitl16":{'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024], 'layer_idxs': [4, 11, 17, 23]},
    "vith16plus": {'encoder': 'vith', 'features': 384, 'out_channels': [1280, 1280, 1280, 1280], 'layer_idxs': [7, 15, 23, 31]},
}
5
+
InfiniDepth/model/block/convolution.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from functools import partial
11
+ from typing import Callable
12
+ import collections
13
+ from torch import Tensor
14
+ from itertools import repeat
15
+
16
+
17
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
    r"""Sample `input` at the points `coords` using bilinear interpolation.

    Thin wrapper around `torch.nn.functional.grid_sample` that accepts
    absolute (pixel/frame) coordinates instead of the [-1, 1] grid
    convention.

    `input` is `(B, C, H, W)` with `coords` `(B, H_o, W_o, 2)` holding
    `(x, y)` points, or `(B, C, T, H, W)` with coordinate triplets
    `(t, x, y)` — note `grid_sample` itself expects `(x, y, t)`; the
    reordering is handled here.

    With `align_corners=True`, `x` ranges over `[0, W - 1]` where 0 is the
    center of the left-most pixel; with `align_corners=False`, `x` ranges
    over `[0, W]` where 0 is the left edge of the left-most pixel. The
    same conventions apply to `y` (with `H`) and `t` (with `T`).

    Args:
        input (Tensor): batch of images or volumes to sample from.
        coords (Tensor): batch of sampling coordinates.
        align_corners (bool, optional): coordinate convention. Defaults to `True`.
        padding_mode (str, optional): out-of-range handling. Defaults to `"border"`.

    Returns:
        Tensor: the sampled values.
    """
    spatial_sizes = input.shape[2:]
    assert len(spatial_sizes) in [2, 3]

    if len(spatial_sizes) == 3:
        # (t, x, y) -> (x, y, t): grid_sample orders dimensions as W, H, T.
        coords = coords[..., [1, 2, 0]]

    # Rescale absolute coordinates into grid_sample's [-1, 1] range.
    if align_corners:
        scale = [2 / max(size - 1, 1) for size in reversed(spatial_sizes)]
    else:
        scale = [2 / size for size in reversed(spatial_sizes)]
    coords = coords * torch.tensor(scale, device=coords.device)
    coords = coords - 1

    return F.grid_sample(
        input, coords, align_corners=align_corners, padding_mode=padding_mode
    )
82
+
83
+
84
def round_to_multiple_of_4(n):
    """Round `n` to the nearest multiple of 4.

    Ties follow Python's banker's rounding (e.g. 2 -> 0 and 6 -> 8,
    since round(0.5) == 0 and round(1.5) == 2).
    """
    quotient = round(n / 4)
    return quotient * 4
86
+
87
+
88
+
89
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with normalization and a residual connection.

    When ``stride != 1`` the skip path is downsampled by a strided 1x1
    convolution (followed by ``norm3``) so it matches the main branch.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        norm_fn: One of "group", "batch", "instance", "none".
        stride: Stride of the first conv (and of the downsample branch).

    Raises:
        ValueError: if ``norm_fn`` is not one of the supported names.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
        )
        self.relu = nn.ReLU(inplace=True)

        # NOTE(review): assumes `planes` is a multiple of 8 when
        # norm_fn == "group" — confirm callers.
        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        else:
            # BUGFIX: previously an unknown norm_fn silently left
            # norm1/norm2 undefined, producing a confusing AttributeError
            # only when forward() was first called. Fail fast instead.
            raise ValueError(f"Unknown norm_fn: {norm_fn}")

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        """Return relu(x + F(x)), downsampling the skip path if needed."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
149
+
150
+
151
class BasicEncoder(nn.Module):
    """CNN feature extractor producing one `output_dim`-channel map at
    1/`stride` of the input resolution.

    A strided stem convolution is followed by four residual stages at
    decreasing resolutions; every stage output is resized to the target
    resolution, concatenated, and fused by two convolutions.
    """

    def __init__(self, input_dim=3, output_dim=128, stride=4):
        super().__init__()
        self.stride = stride
        self.norm_fn = "instance"
        self.in_planes = output_dim // 2
        self.norm1 = nn.InstanceNorm2d(self.in_planes)
        self.norm2 = nn.InstanceNorm2d(output_dim * 2)

        self.conv1 = nn.Conv2d(
            input_dim,
            self.in_planes,
            kernel_size=7,
            stride=2,
            padding=3,
            padding_mode="zeros",
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(output_dim // 2, stride=1)
        self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
        self.layer3 = self._make_layer(output_dim, stride=2)
        self.layer4 = self._make_layer(output_dim, stride=2)

        # Fuses the channel-concatenation of the four stage outputs.
        self.conv2 = nn.Conv2d(
            output_dim * 3 + output_dim // 4,
            output_dim * 2,
            kernel_size=3,
            padding=1,
            padding_mode="zeros",
        )
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)

        # He initialization for convs; unit weight / zero bias for any
        # affine instance norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(module, (nn.InstanceNorm2d)):
                if module.weight is not None:
                    nn.init.constant_(module.weight, 1)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; only the first may change resolution/width.
        blocks = (
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        _, _, H, W = x.shape

        x = self.relu1(self.norm1(self.conv1(x)))

        feat1 = self.layer1(x)
        feat2 = self.layer2(feat1)
        feat3 = self.layer3(feat2)
        feat4 = self.layer4(feat3)

        target_size = (H // self.stride, W // self.stride)

        def _to_target(t):
            # Resize a stage output to the common target resolution.
            return F.interpolate(t, target_size, mode="bilinear", align_corners=True)

        fused = torch.cat(
            [_to_target(feat1), _to_target(feat2), _to_target(feat3), _to_target(feat4)],
            dim=1,
        )
        fused = self.relu2(self.norm2(self.conv2(fused)))
        return self.conv3(fused)
InfiniDepth/model/block/implicit_decoder.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import sys
7
+ from grpc import insecure_channel
8
+ from sympy import use
9
+ from pathlib import Path
10
+
11
+
12
def exists(val):
    """Return True when `val` holds a value (i.e. is not None)."""
    return val is not None
14
+
15
+
16
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val
18
+
19
+
20
class MLP(nn.Module):
    """Plain fully-connected network.

    Builds (Linear + ReLU) for every entry of `hidden_list`, then — when
    `out_dim` is not None — a final Linear followed by an output
    activation.

    Args:
        in_dim: Input feature size.
        out_dim: Output feature size; None means no output layer (and no
            output activation) is appended.
        hidden_list: Sizes of the hidden layers.
        output_act: "sigmoid", "relu" or "elu"; any other value maps to
            an identity activation.
    """

    def __init__(self, in_dim, out_dim, hidden_list, output_act='elu'):
        super().__init__()
        modules = []
        width = in_dim
        for hidden_dim in hidden_list:
            modules.append(nn.Linear(width, hidden_dim))
            modules.append(nn.ReLU())
            width = hidden_dim

        if out_dim is not None:
            modules.append(nn.Linear(width, out_dim))
            activations = {
                "sigmoid": nn.Sigmoid(),
                "relu": nn.ReLU(),
                "elu": nn.ELU(),
            }
            modules.append(activations.get(output_act, nn.Identity()))

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        return self.layers(x)
42
+
43
+
44
class ImplicitHead(nn.Module):
    """
    Implicit head that fuses DINOv3 semantic features and BasicEncoder low-level features.

    Args:
        hidden_dim: DINOv3 feature dimension (e.g., 1024)
        basic_dim: BasicEncoder feature dimension (e.g., 128)
        fusion_type: Feature fusion strategy
            - "concat": Simple concatenation
            - "gated": Gated fusion with learnable weights
        out_dim: Output dimension (1 for depth)
        hidden_list: MLP hidden layer dimensions

    Raises:
        ValueError: if ``fusion_type`` is not "concat" or "gated".
    """

    def __init__(
        self,
        hidden_dim,                    # 1024 for DINOv3
        basic_dim=128,                 # BasicEncoder output dim
        fusion_type="gated",           # "concat" | "gated"
        out_dim=1,
        # BUGFIX: tuple default instead of a mutable list default shared
        # across instances.
        hidden_list=(1024, 256, 32),
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.basic_dim = basic_dim
        self.fusion_type = fusion_type

        # Determine the MLP input dimension from the fusion strategy.
        if fusion_type == "concat":
            in_channels = hidden_dim + basic_dim
        elif fusion_type == "gated":
            # Project low-level features to hidden_dim, then predict a
            # per-channel sigmoid gate from both feature sets.
            self.gate_proj = nn.Linear(basic_dim, hidden_dim)
            self.gate = nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.Sigmoid()
            )
            in_channels = hidden_dim
        else:
            raise ValueError(f"Unknown fusion_type: {fusion_type}")

        self.out_layer = MLP(
            in_dim=in_channels,
            out_dim=out_dim,
            hidden_list=list(hidden_list),
            output_act='elu'
        )

    def _encode_feat(self, features, patch_h, patch_w):
        """Reshape the last backbone stage's tokens into a [B, C, patch_h, patch_w] map."""
        x = features[-1][0]  # expected [B, patch_h * patch_w, C]
        out_feat = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
        return out_feat

    def _decode_dpt(self, feat, basic_feat, coord):
        """
        Query features at given coordinates and fuse them.

        Args:
            feat: DINOv3 feature map [B, hidden_dim, H_dino, W_dino]
            basic_feat: BasicEncoder feature map [B, basic_dim, H_basic, W_basic], or None
            coord: Query coordinates [B, N, 2] in range [-1, 1]

        Returns:
            pred: Predicted depth [B, N, 1]
        """
        coord_ = coord.clone()
        # Keep samples strictly inside the grid.
        coord_.clamp_(-1 + 1e-6, 1 - 1e-6)

        # Sample DINOv3 features at the query coordinates.
        q_feat_dino = F.grid_sample(
            feat, coord_.flip(-1).unsqueeze(1),
            mode='bilinear', align_corners=False
        )[:, :, 0, :].permute(0, 2, 1)  # [B, N, hidden_dim]

        if basic_feat is not None:
            # Sample BasicEncoder features at the same coordinates.
            q_feat_basic = F.grid_sample(
                basic_feat, coord_.flip(-1).unsqueeze(1),
                mode='bilinear', align_corners=False
            )[:, :, 0, :].permute(0, 2, 1)  # [B, N, basic_dim]
            q_feat_fused = self._fuse_features(q_feat_dino, q_feat_basic)
        else:
            # Without low-level features, fall back to DINOv3 only.
            q_feat_fused = q_feat_dino

        pred = self.out_layer(q_feat_fused)
        return pred

    def _fuse_features(self, feat_dino, feat_basic):
        """
        Fuse DINOv3 and BasicEncoder features.

        Args:
            feat_dino: [B, N, hidden_dim]
            feat_basic: [B, N, basic_dim]

        Returns:
            fused_feat: [B, N, fused_dim]

        Raises:
            ValueError: if ``self.fusion_type`` is not recognized.
        """
        if self.fusion_type == "concat":
            return torch.cat([feat_dino, feat_basic], dim=-1)

        elif self.fusion_type == "gated":
            feat_basic_proj = self.gate_proj(feat_basic)  # [B, N, hidden_dim]
            gate_input = torch.cat([feat_dino, feat_basic_proj], dim=-1)
            gate_weights = self.gate(gate_input)  # [B, N, hidden_dim]
            # Convex combination of semantic and projected low-level features.
            return gate_weights * feat_dino + (1 - gate_weights) * feat_basic_proj

        # BUGFIX: previously fell through and returned None for an
        # unexpected fusion type (unreachable via __init__, but fail loudly).
        raise ValueError(f"Unknown fusion_type: {self.fusion_type}")

    def forward(self, features, basic_feat, patch_h, patch_w, coords):
        """
        Forward pass.

        Args:
            features: DINOv3 features from backbone
            basic_feat: BasicEncoder features [B, basic_dim, H/4, W/4], or None
            patch_h, patch_w: DINOv3 feature map spatial size
            coords: Query coordinates [B, N, 2]

        Returns:
            dpt_pred: Predicted depth [B, N, 1]
        """
        # Token sequence -> spatial feature map.
        feat = self._encode_feat(features, patch_h, patch_w)

        # Query and fuse features at the requested coordinates.
        dpt_pred = self._decode_dpt(feat, basic_feat, coords)

        return dpt_pred
InfiniDepth/model/block/pe.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from typing import Any, Optional, Tuple, Dict
9
+
10
# Reduced-precision dtype selection: bfloat16 on GPUs with compute
# capability >= 8 (Ampere or newer), float16 otherwise.
# NOTE(review): only device 0's capability is checked, and the name
# suggests this is an accumulation dtype — confirm against usage.
if torch.cuda.is_available():
    acc_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
    acc_dtype = torch.float16
14
+
15
+ POS_EMB_REGISTRY = {}
16
+
17
def register_pos_emb(name):
    """Return a class decorator that registers the decorated class in
    POS_EMB_REGISTRY under the lowercased `name`."""
    key = name.lower()

    def _register(cls):
        POS_EMB_REGISTRY[key] = cls
        return cls

    return _register
22
+
23
+
24
+ @register_pos_emb("dct")
25
class DctPositionEmbedding(nn.Module):
    """Separable 2D DCT positional encoding for query coordinates.

    For coords [B, N, 2] with values in [0, 1]:
        Phi(x, y)[fx, fy] = cos(pi * fx * x) * cos(pi * fy * y) / (1 + fx * fy)
    """

    def __init__(self, max_freqs: int = 8):
        super().__init__()
        self.max_freqs = max_freqs

        freq_bins = torch.arange(max_freqs).float()  # 0 .. F-1
        # Per frequency-pair weights 1 / (1 + fx * fy).
        weight_grid = 1.0 / (1.0 + freq_bins.view(-1, 1) * freq_bins.view(1, -1))

        self.register_buffer("_freqs_1d", freq_bins, persistent=False)
        self.register_buffer("_coeffs_2d", weight_grid, persistent=False)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """
        coords: [B, N, 2], value range should be [0, 1]
        return: [B, N, F^2]
        """
        assert coords.dim() == 3 and coords.size(-1) == 2, "coords must be [B, N, 2]"
        B, N, _ = coords.shape

        freqs = self._freqs_1d.to(device=coords.device, dtype=coords.dtype)
        coeffs = self._coeffs_2d.to(device=coords.device, dtype=coords.dtype)
        F = freqs.numel()  # number of frequency bins (shadows functional alias locally)

        # Per-axis cosine banks, each [B, N, F].
        cos_x = torch.cos(math.pi * coords[..., 0:1] * freqs.view(1, 1, F))
        cos_y = torch.cos(math.pi * coords[..., 1:2] * freqs.view(1, 1, F))

        # Outer product over the two frequency axes, weighted and flattened.
        grid = cos_x.unsqueeze(-1) * cos_y.unsqueeze(-2) * coeffs.view(1, 1, F, F)
        return grid.reshape(B, N, F * F)
66
+
67
+
68
+ @register_pos_emb("random")
69
class RandomPositionEmbedding(nn.Module):
    """
    Positional encoding using random spatial frequencies (Fourier features).

    `forward` produces a [C, H, W] encoding for a feature grid; with
    image_pe_method == "image" a pixel-level encoding is additionally
    patch-embedded back to the grid resolution and added.

    Raises:
        ValueError: from forward(), for an unrecognized image_pe_method.
    """

    # NOTE(review): assumed to match the paired backbone's patch size — confirm.
    patch_size = 14

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None, image_pe_method: str = "patch") -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        # Random projection matrix for the Fourier features; fixed after init.
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

        self.image_pe_method = image_pe_method
        if self.image_pe_method == "image":
            # Downsamples a pixel-level encoding back to the patch grid.
            self.patch_embed = nn.Sequential(
                nn.Conv2d(num_pos_feats * 2, num_pos_feats // 2, kernel_size=2, stride=2),
                nn.ReLU(),
                nn.Conv2d(
                    num_pos_feats // 2,
                    num_pos_feats * 2,
                    kernel_size=self.patch_size // 2,
                    stride=self.patch_size // 2,
                ),
            )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # [0,1] -> [-1,1]; input is d_1 x ... x d_n x 2.
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # Output is d_1 x ... x d_n x C.
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward_encoding(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate a [C, H, W] positional encoding for an h x w grid."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        # Cell-center coordinates, normalized to (0, 1).
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w
        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))  # HxWxC
        return pe.permute(2, 0, 1)  # C x H x W

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        if self.image_pe_method == "patch":
            return self.forward_encoding(size)
        elif self.image_pe_method == "image":
            pe_encoding = self.forward_encoding(size)
            pe_encoding_high = self.forward_encoding((size[0] * self.patch_size, size[1] * self.patch_size))
            return pe_encoding + self.patch_embed(pe_encoding_high[None])[0]
        # BUGFIX: previously fell through and silently returned None for an
        # unrecognized method; fail loudly instead.
        raise ValueError(f"Unknown image_pe_method: {self.image_pe_method}")

    def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
        """Positionally encode (x, y) points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
134
+
135
+
136
+ @register_pos_emb("rope")
137
class RoPEPositionEmbedding(nn.Module):
    """2D Rotary Position Embedding with support for continuous coordinates.

    For each coordinate p (can be float), directly compute theta = p * inv_freq,
    then derive cos/sin and rotate the features.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        # BUGFIX: cache inv_freq per (dim, device, dtype) instead of per dim
        # only — the old cache pinned the tensor to the first device/dtype
        # seen, breaking after .to(device) or dtype changes.
        self._inv_freq_cache: Dict[tuple, torch.Tensor] = {}

    def _get_inv_freq(self, dim: int, device: torch.device, dtype: torch.dtype):
        """
        Computes frequency components for rotary embeddings.
        Returns an inv_freq vector of length dim/2, in the form 1 / base_freq^(2i/d).
        """
        key = (dim, device, dtype)
        if key not in self._inv_freq_cache:
            # Frequencies are taken on even indices only.
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency ** exponents)
            self._inv_freq_cache[key] = inv_freq.to(dtype)
        return self._inv_freq_cache[key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Rotation helper: split x into halves (u, v) and return (-v, u)."""
        D = x.shape[-1]
        x1, x2 = x[..., : D//2], x[..., D//2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope_continuous(
        self,
        x: torch.Tensor,        # [B, N, d_half]
        pos: torch.Tensor,      # [B, N] floating-point coordinates
        inv_freq: torch.Tensor  # [d_half // 2]
    ) -> torch.Tensor:
        # 1) Angles as the outer product of positions and inverse frequencies.
        angles = pos.unsqueeze(-1) * inv_freq.unsqueeze(0)
        # 2) Duplicate so the angle vector matches the feature dimension.
        angles = torch.cat([angles, angles], dim=-1)

        # 3) cos/sin of the angles, broadcast over the batch.
        cos = angles.cos()
        sin = angles.sin()

        # 4) Standard rotary update: x*cos + rotate(x)*sin.
        return x * cos + self._rotate_features(x) * sin

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """
        tokens: [B, N, dim]
        positions: [B, N, 2] continuous coords: (y, x)
        """
        B, N, D = tokens.shape
        assert D % 2 == 0, "Feature dimension must be even"

        assert positions.shape == (B, N, 2), "positions must be [B, N, 2]"

        # Allocate half of the features to each spatial direction.
        d_half = D // 2

        inv_freq = self._get_inv_freq(d_half, tokens.device, tokens.dtype)  # [d_half // 2]
        # First half of the features <- y direction, second half <- x direction.
        tok_v, tok_h = tokens[..., :d_half], tokens[..., d_half:]

        out_v = self._apply_1d_rope_continuous(tok_v, positions[..., 0], inv_freq)
        out_h = self._apply_1d_rope_continuous(tok_h, positions[..., 1], inv_freq)

        return torch.cat([out_v, out_h], dim=-1)
210
+
211
+
212
def build_pos_emb(pos_emb_type="nerf", **kwargs):
    """Instantiate a registered positional-embedding class by (case-insensitive) name.

    Raises:
        ValueError: if nothing is registered under `pos_emb_type`.
    """
    key = pos_emb_type.lower()
    if key not in POS_EMB_REGISTRY:
        raise ValueError(f"Unknown pos_emb_type: {key}")
    return POS_EMB_REGISTRY[key](**kwargs)
217
+
218
+
219
+
220
+
221
+
222
+
InfiniDepth/model/block/perceive_io.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional, Union
2
+ import torch
3
+ from einops import rearrange
4
+ from ...utils.logger import Log
5
+ from torch import Tensor, nn
6
+
7
+ try:
8
+ from xformers.ops import memory_efficient_attention, unbind
9
+
10
+ XFORMERS_AVAILABLE = True
11
+ except ImportError:
12
+ Log.warning("xFormers not available")
13
+ XFORMERS_AVAILABLE = False
14
+
15
+
16
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from `x`, keys/values from `context`.

    Args:
        dim: Query feature dimension (and output dimension).
        context_dim: Context feature dimension. NOTE(review): forward()
            reshapes the k/v projection with `C // num_heads` where C comes
            from `x`, which effectively assumes context_dim == dim — confirm
            with callers.
        num_heads: Number of attention heads.
        qkv_bias: Bias in the q and k/v projections.
        proj_bias: Bias in the output projection.
        attn_drop: Dropout on the attention weights.
        proj_drop: Dropout after the output projection.
        pe: Positional-encoding mode ("normal", "qk", "apply"); "qk" adds
            LayerNorms consumed by MemEffCrossAttention.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        pe: str = "normal",
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim, bias=qkv_bias)  # query projection
        self.qkv_context = nn.Linear(context_dim, context_dim * 2, bias=qkv_bias)  # key + value projection
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.pe = pe
        if self.pe == "qk":
            self.norm1 = nn.LayerNorm(dim // num_heads)
            self.norm2 = nn.LayerNorm(dim // num_heads)

    def forward(self, x: Tensor, context: Tensor) -> Tensor:
        """x: [B, N, C] queries; context: [B, M, C] keys/values. Returns [B, N, C]."""
        B, N, C = x.shape
        _, M, _ = context.shape

        qkv_x = self.qkv(x).reshape(B, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q_x = qkv_x[0] * self.scale

        qkv_context = (
            self.qkv_context(context).reshape(B, M, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        )
        # BUGFIX: values previously reused index 0 (the keys), leaving the
        # value half of qkv_context unused; index 1 holds the values,
        # matching unbind(qkv_context, 2) in MemEffCrossAttention.
        k_context, v_context = qkv_context[0], qkv_context[1]

        # Cross-attention: queries from x, keys/values from context.
        attn = q_x @ k_context.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v_context).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
65
+
66
+
67
class MemEffCrossAttention(CrossAttention):
    """CrossAttention variant backed by xFormers' memory_efficient_attention.

    Falls back to the plain parent implementation when xFormers is not
    available (in which case `attn_bias` must be None).
    """

    def forward(
        self, x: Tensor, context: Tensor, x_pe: Tensor = None, context_pe: Tensor = None, attn_bias=None
    ) -> Tensor:
        # x: [B, N, C] queries; context: [B, M, C_context] keys/values.
        if not XFORMERS_AVAILABLE:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            return super().forward(x, context)

        B, N, C = x.shape
        _, M, C_context = context.shape

        # Single "q" slot; unbind drops the slot axis -> q_x: [B, N, heads, head_dim].
        qkv_x = self.qkv(x).reshape(B, N, 1, self.num_heads, C // self.num_heads)
        (q_x,) = unbind(qkv_x, 2)

        # Separate k and v from the context projection.
        qkv_context = self.qkv_context(context).reshape(B, M, 2, self.num_heads, C_context // self.num_heads)
        k_context, v_context = unbind(qkv_context, 2)

        if self.pe == "qk":
            # Add per-head positional encodings to q and k, then LayerNorm.
            q_x = self.norm1(q_x + rearrange(x_pe, "b n (m c) -> b n m c", m=self.num_heads))
            k_context = self.norm2(k_context + rearrange(context_pe, "b n (m c) -> b n m c", m=self.num_heads))
        elif self.pe == "apply":
            # NOTE(review): presumably positional encoding was already applied upstream — confirm.
            pass
        # Memory-efficient cross-attention
        x = memory_efficient_attention(
            q_x.to(dtype=v_context.dtype), k_context.to(dtype=v_context.dtype), v_context, attn_bias=attn_bias
        )
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
98
+
99
+
100
+ # class Attention(nn.Module):
101
+ # def __init__(
102
+ # self,
103
+ # dim: int,
104
+ # num_heads: int = 8,
105
+ # qkv_bias: bool = False,
106
+ # proj_bias: bool = True,
107
+ # attn_drop: float = 0.0,
108
+ # proj_drop: float = 0.0,
109
+ # ) -> None:
110
+ # super().__init__()
111
+ # self.num_heads = num_heads
112
+ # head_dim = dim // num_heads
113
+ # self.scale = head_dim**-0.5
114
+
115
+ # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
116
+ # self.attn_drop = nn.Dropout(attn_drop)
117
+ # self.proj = nn.Linear(dim, dim, bias=proj_bias)
118
+ # self.proj_drop = nn.Dropout(proj_drop)
119
+
120
+ # def forward(self, x: Tensor) -> Tensor:
121
+ # B, N, C = x.shape
122
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
123
+
124
+ # q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
125
+ # attn = q @ k.transpose(-2, -1)
126
+
127
+ # attn = attn.softmax(dim=-1)
128
+ # attn = self.attn_drop(attn)
129
+
130
+ # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
131
+ # x = self.proj(x)
132
+ # x = self.proj_drop(x)
133
+ # return x
134
+
135
+
136
class Attention(nn.Module):
    """Standard multi-head self-attention with optional q/k positional LayerNorms.

    Args:
        dim: token channel dimension.
        num_heads: number of attention heads.
        qkv_bias: add bias to the fused qkv projection.
        proj_bias: add bias to the output projection.
        attn_drop: dropout on attention weights.
        proj_drop: dropout on the projected output.
        pe: positional-encoding strategy; "qk" adds per-head LayerNorms applied
            after injecting ``x_pe`` into queries and keys.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        pe: str = "normal",
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.pe = pe
        if self.pe == "qk":
            # Per-head LayerNorms used when positional encodings are added to q/k.
            self.norm1 = nn.LayerNorm(dim // num_heads)
            self.norm2 = nn.LayerNorm(dim // num_heads)

    def forward(self, x: Tensor, x_pe: Tensor = None) -> Tensor:
        batch, tokens, channels = x.shape
        projected = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )

        query = projected[0] * self.scale
        key, value = projected[1], projected[2]
        if self.pe == "qk":
            # NOTE(review): x_pe is added without a per-head rearrange here, unlike
            # MemEffAttention — assumes the caller already passes per-head layout.
            query = self.norm1(query + x_pe)
            key = self.norm2(key + x_pe)
        elif self.pe == "apply":
            pass
        scores = query @ key.transpose(-2, -1)

        scores = scores.softmax(dim=-1)
        scores = self.attn_drop(scores)

        out = (scores @ value).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj(out)
        return self.proj_drop(out)
180
+
181
+
182
class MemEffAttention(Attention):
    """Self-attention using xFormers' memory-efficient kernel when available.

    Falls back to the reference ``Attention.forward`` when xFormers is missing
    (``attn_bias`` must then be ``None``).
    """

    def forward(
        self,
        x: Tensor,
        x_pe: Tensor = None,
        attn_bias=None,
    ) -> Tensor:
        if not XFORMERS_AVAILABLE:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            # BUGFIX: forward x_pe as well; previously it was dropped, so the
            # pe == "qk" path crashed (q + None) in the non-xFormers fallback.
            return super().forward(x, x_pe)

        B, N, C = x.shape
        # (B, N, 3, heads, head_dim) -> separate q, k, v without a permute.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = unbind(qkv, 2)
        if self.pe == "qk":
            # Split the flat positional encoding into per-head channels before adding.
            q = self.norm1(q + rearrange(x_pe, "b n (m c) -> b n m c", m=self.num_heads))
            k = self.norm2(k + rearrange(x_pe, "b n (m c) -> b n m c", m=self.num_heads))
        elif self.pe == "apply":
            pass
        # Cast q/k to v's dtype: norm1/norm2 may change their dtype, which would
        # otherwise make memory_efficient_attention raise a dtype-mismatch error.
        x = memory_efficient_attention(q.to(dtype=v.dtype), k.to(dtype=v.dtype), v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
210
+
211
+
212
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    Hidden and output widths default to ``in_features`` when not provided.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_width = hidden_features or in_features
        out_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width, bias=bias)
        # A single Dropout module is shared by both stages.
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
237
+
238
+
239
class LayerScale(nn.Module):
    """Learnable per-channel scaling of the input (LayerScale, as in CaiT).

    Args:
        dim: number of channels to scale.
        init_values: initial value(s) for the scale parameter ``gamma``.
        inplace: if True, multiply into the input tensor in place.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
252
+
253
+
254
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples (stochastic depth), rescaling survivors by 1/keep_prob.

    Identity when ``drop_prob`` is 0 or when not training.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        # Rescale so the expected value of the output matches the input.
        mask.div_(keep_prob)
    return x * mask
264
+
265
+
266
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        # drop_prob: probability of dropping a sample's residual path.
        # NOTE(review): defaults to None; ``drop_path`` would raise when training
        # with drop_prob=None — callers appear expected to pass an explicit float.
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegates to the module-level ``drop_path``; identity when not training.
        return drop_path(x, self.drop_prob, self.training)
InfiniDepth/model/block/prompt_models/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from .sam import SAMPromptModel
3
+ from .selfattn import SelfAttnPromptModel
4
+
5
# Public names exported by the prompt_models package.
__all__ = [
    "GeneralPromptModel",
    "SelfAttnPromptModel",
    "SAMPromptModel",
]
10
+
11
+
12
class GeneralPromptModel(nn.Module):
    """Applies a prompt block to the encoder features at selected stages.

    Args:
        prompt_stage: indices of the feature stages to run the prompt block on;
            defaults to ``[3]``.
        block (keyword): an ``nn.Module`` template for the prompt block; one
            independent copy is created per prompted stage.
    """

    def __init__(self, prompt_stage=None, **kwargs):
        import copy

        super().__init__()
        # Avoid a mutable default argument; the default behavior is still stage [3].
        self.prompt_stage = [3] if prompt_stage is None else prompt_stage
        # Map absolute stage index -> position in self.prompt_model.
        self.prompt_idmap = {i: idx for idx, i in enumerate(self.prompt_stage)}
        block = kwargs.get("block")
        # BUGFIX: the original put the *same* module instance in the list once per
        # stage, silently sharing weights across stages; deep-copy so each stage
        # owns an independent block (values start identical, training diverges).
        self.prompt_model = nn.ModuleList([copy.deepcopy(block) for _ in range(len(self.prompt_stage))])

    def forward(self, features, prompt_depth, prompt_mask, patch_h, patch_w):
        """Run the prompt block on each stage listed in ``prompt_stage``; other stages pass through."""
        for i in range(len(features)):
            if i not in self.prompt_stage:  # e.g. prompt_stage = [3]
                continue
            features[i][0] = self.prompt_model[self.prompt_idmap[i]](
                features[i][0],
                prompt_depth,
                prompt_mask,
                patch_h,
                patch_w,
            )
        return features
InfiniDepth/model/block/prompt_models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.49 kB). View file
 
InfiniDepth/model/block/prompt_models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.44 kB). View file
 
InfiniDepth/model/block/prompt_models/__pycache__/crossattn.cpython-310.pyc ADDED
Binary file (5.79 kB). View file
 
InfiniDepth/model/block/prompt_models/__pycache__/diffattn.cpython-310.pyc ADDED
Binary file (142 Bytes). View file
 
InfiniDepth/model/block/prompt_models/__pycache__/rope.cpython-310.pyc ADDED
Binary file (8.28 kB). View file
 
InfiniDepth/model/block/prompt_models/__pycache__/rope.cpython-311.pyc ADDED
Binary file (13 kB). View file
 
InfiniDepth/model/block/prompt_models/__pycache__/sam.cpython-310.pyc ADDED
Binary file (3.51 kB). View file
 
InfiniDepth/model/block/prompt_models/__pycache__/sam.cpython-311.pyc ADDED
Binary file (5.48 kB). View file
 
InfiniDepth/model/block/prompt_models/__pycache__/selfattn.cpython-310.pyc ADDED
Binary file (8.51 kB). View file
 
InfiniDepth/model/block/prompt_models/__pycache__/selfattn.cpython-311.pyc ADDED
Binary file (15.9 kB). View file
 
InfiniDepth/model/block/prompt_models/crossattn.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Tuple
4
+ from ..perceive_io import LayerScale, MemEffCrossAttention, Mlp
5
+ from .utils.pe_utils import PositionEmbeddingRandom
6
+ from torch import Tensor
7
+
8
+
9
class CrossAttnPromptModel(nn.Module):
    """Injects sparse depth prompts into image features via cross-attention blocks.

    Each valid pixel in ``prompt_mask`` contributes a prompt token built from its
    depth value; the image tokens then exchange information with the prompt
    tokens through a stack of ``CrossAttenPromptBlock``s.

    Args:
        transformer_dim: channel dimension of the image/prompt tokens.
        num_blocks: number of cross-attention blocks to stack.
        num_heads: attention heads per block.
        pe: positional-encoding strategy ("normal", "qk" or "apply").
        image_pe_method: how image positional encodings are generated
            ("patch" or "image"); forwarded to ``PositionEmbeddingRandom``.
    """

    def __init__(
        self,
        transformer_dim: int = 1024,
        num_blocks: int = 1,
        num_heads: int = 4,
        pe: str = "normal",
        image_pe_method: str = "patch",  # or "image"
        **kwargs,
    ) -> None:
        super().__init__()
        self.pe = pe
        pe_dim = transformer_dim // 2
        if self.pe == "apply":
            # Per-head PE: the encoding is later split across attention heads.
            pe_dim = pe_dim // num_heads
        self.pe_layer = PositionEmbeddingRandom(pe_dim, image_pe_method=image_pe_method)
        self.prompt_blocks = nn.ModuleList(
            [
                CrossAttenPromptBlock(dim=transformer_dim, num_heads=num_heads, first_block=(i == 0), pe=pe)
                for i in range(num_blocks)
            ]
        )
        # Lifts a scalar depth value to a prompt token of width transformer_dim.
        self.depth2feature = nn.Sequential(
            nn.Linear(1, transformer_dim // 2),
            nn.GELU(),
            nn.Linear(transformer_dim // 2, transformer_dim),
        )

    # BUGFIX: the return annotation previously claimed Tuple[Tensor, Tensor],
    # but a single tensor is (and was) returned.
    def forward(
        self,
        image_embeddings: torch.Tensor,
        prompt_depth: torch.Tensor,
        prompt_mask: torch.Tensor,
        patch_h: int,
        patch_w: int,
    ) -> torch.Tensor:
        """Fuse sparse depth prompts into the image embeddings.

        Args:
            image_embeddings: (B, patch_h * patch_w, C) image tokens.
            prompt_depth: (B, 1, H, W) dense depth map holding prompt values.
            prompt_mask: (B, 1, H, W) mask; entries > 0 mark valid prompt pixels.
            patch_h: patch-grid height of the image embeddings.
            patch_w: patch-grid width of the image embeddings.

        Returns:
            torch.Tensor: image embeddings of the same shape, updated by the
            prompts. Samples without valid prompt pixels pass through unchanged.
        """
        B, _, H, W = prompt_depth.shape
        image_pe = self.pe_layer((patch_h, patch_w)).permute(1, 2, 0)  # CxHxW -> HxWxC
        image_embeddings_list = []
        for b in range(B):
            valid_pts_num = (prompt_mask[b, 0] > 0.0).sum()
            if valid_pts_num == 0:
                # No prompts for this sample: keep its embeddings untouched.
                image_embeddings_list.append(image_embeddings[b : (b + 1)])
                continue
            # Normalized (row, col) centers of the valid prompt pixels.
            sparse_depth_pos = (prompt_mask[b, 0] > 0.0).nonzero().float()
            sparse_depth_pos[:, 0] = (sparse_depth_pos[:, 0] + 0.5) / H
            sparse_depth_pos[:, 1] = (sparse_depth_pos[:, 1] + 0.5) / W
            sparse_depth = prompt_depth[b, 0][prompt_mask[b, 0] > 0.0]
            prompt_embeddings = self.depth2feature(sparse_depth[:, None])[None, ...]  # 1, N, C
            # The PE layer expects (x, y) order, hence the [1, 0] index swap.
            prompt_pe = self.pe_layer._pe_encoding(sparse_depth_pos[None, :, [1, 0]])  # 1, N, C
            query_pe = image_pe.reshape(1, -1, image_pe.shape[-1])
            prompt = prompt_embeddings
            query = image_embeddings[b : (b + 1)]
            for block in self.prompt_blocks:
                query, prompt = block(query, query_pe, prompt, prompt_pe)
            image_embeddings_list.append(query[..., : image_embeddings.shape[-1]])
        return torch.cat(image_embeddings_list, dim=0)
99
+
100
+
101
class CrossAttenPromptBlock(nn.Module):
    """Bidirectional cross-attention block for prompt-based processing.

    The query stream (``x``) first attends to the context stream; the context
    then attends to the *already-updated* queries; each stream finishes with its
    own MLP. Supports several positional-encoding strategies via ``pe``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        init_values: float = 0.0,
        first_block: bool = False,
        pe: str = "normal",
        **kwargs,
    ) -> None:
        super().__init__()
        self.first_block = first_block
        self.pe = pe

        # Attention components (separate normalization per stream and direction).
        self.norm1_x = nn.LayerNorm(dim)
        self.norm1_x_after = nn.LayerNorm(dim)
        self.attn_x = MemEffCrossAttention(dim, context_dim=dim, num_heads=num_heads, pe=pe)
        # LayerScale only when init_values is non-zero; otherwise plain residual.
        self.ls1_x = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.norm1_context = nn.LayerNorm(dim)
        self.attn_context = MemEffCrossAttention(dim, context_dim=dim, num_heads=num_heads, pe=pe)
        self.ls1_context = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        # MLP components
        self.norm2_x = nn.LayerNorm(dim)
        self.mlp_x = Mlp(dim, hidden_features=dim * 4)
        self.ls2_x = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.norm2_context = nn.LayerNorm(dim)
        self.mlp_context = Mlp(dim, hidden_features=dim * 4)
        self.ls2_context = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

    def forward(self, x: Tensor, x_pe: Tensor, context: Tensor, context_pe: Tensor) -> Tuple[Tensor, Tensor]:
        # With "normal" PE, positional encodings are added once, before the first block.
        if self.pe == "normal" and self.first_block:
            x = x + x_pe
            context = context + context_pe

        # Non-"normal" PE: pass the encodings into the attention modules instead.
        if self.pe != "normal":
            x = x + self.ls1_x(
                self.attn_x(self.norm1_x(x), context=self.norm1_context(context), x_pe=x_pe, context_pe=context_pe)
            )
            # Note: context attends to the updated x (normalized by norm1_x_after).
            context = context + self.ls1_context(
                self.attn_context(
                    self.norm1_context(context), context=self.norm1_x_after(x), x_pe=context_pe, context_pe=x_pe
                )
            )
        else:
            # Apply standard attention
            x = x + self.ls1_x(self.attn_x(self.norm1_x(x), context=self.norm1_context(context)))
            context = context + self.ls1_context(
                self.attn_context(self.norm1_context(context), context=self.norm1_x_after(x))
            )

        # Apply MLP
        x = x + self.ls2_x(self.mlp_x(self.norm2_x(x)))
        context = context + self.ls2_context(self.mlp_context(self.norm2_context(context)))

        return x, context
InfiniDepth/model/block/prompt_models/rope.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Tuple
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
# Accumulation dtype for the positional-encoding math: bfloat16 on GPUs with
# compute capability >= 8 (Ampere and newer), float16 otherwise (older GPUs or
# CPU-only environments).
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    acc_dtype = torch.bfloat16
else:
    acc_dtype = torch.float16
10
+
11
+
12
class PositionGetter:
    """Generates and caches (y, x) patch coordinates for a 2D grid.

    Coordinates are memoized per (height, width), so repeated calls with the
    same grid size reuse the precomputed tensor instead of rebuilding it.
    """

    def __init__(self):
        """Create the generator with an empty memoization cache."""
        # (height, width) -> flattened (H*W, 2) coordinate tensor.
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Return a (batch_size, height*width, 2) tensor of y/x grid coordinates.

        Each batch item receives an identical copy of the flattened grid.
        """
        key = (height, width)
        if key not in self.position_cache:
            rows = torch.arange(height, device=device)
            cols = torch.arange(width, device=device)
            self.position_cache[key] = torch.cartesian_prod(rows, cols)

        grid = self.position_cache[key]
        return grid.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
48
+
49
+
50
class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    Applies rotary position embeddings to input tokens based on their 2D spatial
    positions, rotating features separately for the vertical and horizontal
    dimensions. Frequency components are cached per (dim, seq_len, device, dtype).

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0
        feat_dim: Feature width used by the encoding helpers. Default: 1024

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Factor to scale the computed frequencies (stored but not
            used by the methods in this class).
        frequency_cache: Cache for storing precomputed frequency components.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0, feat_dim: int = 1024):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
        # NOTE(review): hard-coded 14 — presumably the ViT patch size; confirm
        # against the backbone configuration.
        self.patch_size = 14
        self.feat_dim = feat_dim

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes (and caches) cosine/sine frequency components.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Maximum sequence length.
            device: Target device for computations.
            dtype: Data type for the computed tensors.

        Returns:
            Tuple of (cosine, sine) tensors of shape (seq_len, dim).
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Compute frequency bands: inv_freq[j] = base ** -(2j/dim)
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # Generate position-dependent frequencies (outer product pos x freq)
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)

            # Duplicate the half-width angle table so it spans the full dim,
            # then cache the cos/sin components in the requested dtype.
            angles = angles.to(dtype)
            angles = torch.cat((angles, angles), dim=-1)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Performs feature rotation by splitting and recombining feature dimensions.

        Maps (x1, x2) -> (-x2, x1), the 90-degree rotation used by RoPE.

        Args:
            x: Input tensor to rotate.

        Returns:
            Rotated feature tensor.
        """
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self,
        tokens: torch.Tensor,
        positions: torch.Tensor,
        cos_comp: torch.Tensor,
        sin_comp: torch.Tensor,
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one dimension.

        Args:
            tokens: Input token features.
            positions: Integer position indices (used as embedding-table lookups).
            cos_comp: Cosine components for rotation.
            sin_comp: Sine components for rotation.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        # Look up each position's cos/sin row; the extra axis broadcasts over heads.
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]

        # Standard RoPE rotation: x*cos + rotate(x)*sin
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def _forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                The feature dimension (dim) must be divisible by 4.
            positions: Position tensor of shape (batch_size, n_tokens, 2) containing
                the y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        # Validate inputs
        assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        # Half of the feature width is rotated per spatial direction.
        feature_dim = tokens.size(-1) // 2

        # Frequency tables must cover the largest position index present.
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)

        # Split features for vertical and horizontal processing
        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        # Apply RoPE separately for each dimension (y drives the first half, x the second)
        vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
        horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)

        # Combine processed features
        return torch.cat((vertical_features, horizontal_features), dim=-1)

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1].

        NOTE(review): despite the wording above (kept from the original),
        ``F.embedding`` requires integer indices, so ``coords`` must hold
        integer position indices at call time — confirm against callers.
        """
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        max_position = int(coords.max()) + 1
        # acc_dtype is the module-level accumulation dtype chosen at import time.
        cos_comp, sin_comp = self._compute_frequency_components(self.feat_dim, max_position, coords.device, acc_dtype)
        vertical_cos = F.embedding(coords[..., 0], cos_comp)
        vertical_sin = F.embedding(coords[..., 0], sin_comp)
        horizontal_cos = F.embedding(coords[..., 1], cos_comp)
        horizontal_sin = F.embedding(coords[..., 1], sin_comp)
        # outputs d_1 x ... x d_n x C shape
        return torch.cat((vertical_cos, vertical_sin, horizontal_cos, horizontal_sin), dim=-1)

    def forward_encoding(self, size: Tuple[int, int], device: torch.device) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size.

        Returns a (C, height, width) tensor of concatenated cos/sin encodings.
        """
        height, width = size
        # Pixel-space coordinates of each patch center; presumably assumes the
        # grid comes from a stride-2 arrangement of patch_size patches — confirm.
        y_coords = torch.arange(height, device=device) * (self.patch_size * 2) + self.patch_size - 1
        x_coords = torch.arange(width, device=device) * (self.patch_size * 2) + self.patch_size - 1
        positions = torch.cartesian_prod(y_coords, x_coords)  # h, w
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(self.feat_dim, max_position, device, acc_dtype)
        vertical_cos = F.embedding(positions[..., 0], cos_comp)
        vertical_sin = F.embedding(positions[..., 0], sin_comp)
        horizontal_cos = F.embedding(positions[..., 1], cos_comp)
        horizontal_sin = F.embedding(positions[..., 1], sin_comp)
        return (
            torch.cat((vertical_cos, vertical_sin, horizontal_cos, horizontal_sin), dim=-1)
            .reshape(height, width, -1)
            .permute(2, 0, 1)
        )

    def forward(self, size: Tuple[int, int], device: torch.device) -> torch.Tensor:
        """Alias for :meth:`forward_encoding`."""
        return self.forward_encoding(size, device)
InfiniDepth/model/block/prompt_models/sam.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Tuple, Type
8
+ import torch
9
+ from .utils.pe_utils import PositionEmbeddingRandom
10
+ from .utils.transformer import TwoWayTransformer
11
+ from torch import nn
12
+
13
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
14
+ # All rights reserved.
15
+
16
+ # This source code is licensed under the license found in the
17
+ # LICENSE file in the root directory of this source tree.
18
+
19
+
20
class SAMPromptModel(nn.Module):
    """Fuses sparse depth prompts into image embeddings with a SAM-style two-way transformer.

    Each valid pixel in ``prompt_mask`` becomes a prompt token derived from its
    depth value; a ``TwoWayTransformer`` then lets the image tokens and prompt
    tokens attend to each other.

    Args:
        transformer_dim (int): channel dimension of the transformer tokens.
        mlp_dim (int): hidden width of the transformer MLPs.
        num_heads (int): number of attention heads in the transformer.
        activation: kept for interface compatibility; not used by this module.
    """

    def __init__(
        self,
        *,
        transformer_dim: int,
        mlp_dim: int = 2048,
        num_heads: int = 8,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = TwoWayTransformer(
            depth=2, embedding_dim=transformer_dim, num_heads=num_heads, mlp_dim=mlp_dim
        )
        self.pe_layer = PositionEmbeddingRandom(transformer_dim // 2)
        # Lifts a scalar depth value to a token of width transformer_dim.
        self.depth2feature = nn.Sequential(
            nn.Linear(1, transformer_dim // 2), nn.ReLU(True), nn.Linear(transformer_dim // 2, transformer_dim)
        )

    # BUGFIX: the return annotation previously claimed Tuple[Tensor, Tensor],
    # but a single tensor is (and was) returned.
    def forward(
        self,
        image_embeddings: torch.Tensor,
        prompt_depth: torch.Tensor,
        prompt_mask: torch.Tensor,
        patch_h: int,
        patch_w: int,
    ) -> torch.Tensor:
        """Fuse sparse depth prompts into the image embeddings.

        Args:
            image_embeddings: (B, patch_h * patch_w, C) image tokens.
            prompt_depth: (B, 1, H, W) dense depth map holding prompt values.
            prompt_mask: (B, 1, H, W) mask; entries > 0 mark valid prompt pixels.
            patch_h: patch-grid height of the image embeddings.
            patch_w: patch-grid width of the image embeddings.

        Returns:
            torch.Tensor: updated image embeddings of the same shape as the
            input. Samples without valid prompt pixels pass through unchanged.
        """
        B, _, H, W = prompt_depth.shape
        image_pe = self.pe_layer((patch_h, patch_w)).permute(1, 2, 0)  # CxHxW -> HxWxC

        image_embeddings_list = []
        for b in range(B):
            valid_pts_num = (prompt_mask[b, 0] > 0.0).sum()
            if valid_pts_num == 0:
                # No prompts for this sample: keep its embeddings untouched.
                image_embeddings_list.append(image_embeddings[b : (b + 1)])
                continue
            # Normalized (row, col) centers of the valid prompt pixels.
            sparse_depth_pos = (prompt_mask[b, 0] > 0.0).nonzero().float()
            sparse_depth_pos[:, 0] = (sparse_depth_pos[:, 0] + 0.5) / H
            sparse_depth_pos[:, 1] = (sparse_depth_pos[:, 1] + 0.5) / W
            sparse_depth = prompt_depth[b, 0][prompt_mask[b, 0] > 0.0]
            prompt_embeddings = self.depth2feature(sparse_depth[:, None])[None, ...]  # 1, N, C
            # The PE layer expects (x, y) order, hence the [1, 0] index swap.
            prompt_pe = self.pe_layer._pe_encoding(sparse_depth_pos[None, :, [1, 0]])  # 1, N, C
            prompt_embeddings_item, image_embeddings_item = self.transformer(
                image_embeddings[b : (b + 1)],
                image_pe.reshape(1, -1, image_pe.shape[-1]),
                prompt_embeddings,
                prompt_pe,
            )
            image_embeddings_list.append(image_embeddings_item)
        return torch.cat(image_embeddings_list, dim=0)
102
+
103
+
104
+ # # Lightly adapted from
105
+ # # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
106
+ # class MLP(nn.Module):
107
+ # def __init__(
108
+ # self,
109
+ # input_dim: int,
110
+ # hidden_dim: int,
111
+ # output_dim: int,
112
+ # num_layers: int,
113
+ # sigmoid_output: bool = False,
114
+ # ) -> None:
115
+ # super().__init__()
116
+ # self.num_layers = num_layers
117
+ # h = [hidden_dim] * (num_layers - 1)
118
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
119
+ # self.sigmoid_output = sigmoid_output
120
+
121
+ # def forward(self, x):
122
+ # for i, layer in enumerate(self.layers):
123
+ # x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
124
+ # if self.sigmoid_output:
125
+ # x = F.sigmoid(x)
126
+ # return x
InfiniDepth/model/block/prompt_models/selfattn.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..perceive_io import LayerScale, MemEffAttention, Mlp
5
+ from .rope import RotaryPositionEmbedding2D
6
+ from .utils.pe_utils import PositionEmbeddingRandom
7
+ from torch import Tensor
8
+
9
# Select the autocast accumulation dtype once at import time: bfloat16 on
# GPUs with compute capability >= 8 (Ampere or newer), float16 otherwise
# (older GPUs, or CPU-only builds where CUDA is unavailable).
if torch.cuda.is_available():
    acc_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
    acc_dtype = torch.float16
13
+
14
+
15
class SelfAttnPromptModel(nn.Module):
    """Fuses sparse metric-depth prompts into image embeddings via joint self-attention.

    For every batch item, valid prompt pixels are lifted to token embeddings
    by ``depth2feature`` and jointly attended with the image tokens inside
    ``SelfAttenPromptBlock``; the updated image tokens are returned. Items
    without any valid prompt pass through unchanged.
    """

    def __init__(
        self,
        transformer_dim: int = 1024,
        num_blocks: int = 1,
        num_heads: int = 4,
        pe: str = "normal",
        image_pe_method: str = "patch",  # "patch" or "image"
        **kwargs,
    ) -> None:
        """
        Arguments:
            transformer_dim (int): channel dimension of the image/prompt tokens
            num_blocks (int): number of self-attention prompt blocks
            num_heads (int): attention heads per block
            pe (str): positional-encoding strategy ("normal" adds PE to the
                tokens once before the first block; "apply" forwards the PE
                into the attention layer instead)
            image_pe_method (str): how the image PE grid is produced; forwarded
                to PositionEmbeddingRandom
        """
        super().__init__()
        self.pe = pe
        pe_dim = transformer_dim // 2
        if self.pe == "apply":
            # PE is consumed per attention head in this mode, so shrink it.
            pe_dim = pe_dim // num_heads
        self.pe_layer = PositionEmbeddingRandom(pe_dim, image_pe_method=image_pe_method)
        self.prompt_blocks = nn.ModuleList(
            [
                SelfAttenPromptBlock(dim=transformer_dim, num_heads=num_heads, first_block=(i == 0), pe=pe)
                for i in range(num_blocks)
            ]
        )
        # Small MLP lifting a scalar depth value to a transformer_dim token.
        self.depth2feature = nn.Sequential(
            nn.Linear(1, transformer_dim // 2),
            nn.GELU(),
            nn.Linear(transformer_dim // 2, transformer_dim),
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        prompt_depth: torch.Tensor,
        prompt_mask: torch.Tensor,
        patch_h: int,
        patch_w: int,
    ) -> torch.Tensor:
        """Inject sparse depth prompts into the image embeddings.

        Arguments:
            image_embeddings: (B, N, C) image tokens from the encoder
            prompt_depth: (B, 1, H, W) sparse metric depth map
            prompt_mask: (B, 1, H, W) mask; entries > 0 mark valid prompts
            patch_h, patch_w: token-grid size used to build the image PE

        Returns:
            (B, N, C) image embeddings conditioned on the depth prompts.
            (The original annotation declared a Tuple, but a single tensor
            has always been returned.)
        """
        B, _, H, W = prompt_depth.shape
        image_pe = self.pe_layer((patch_h, patch_w)).permute(1, 2, 0)  # CxHxW -> HxWxC
        image_embeddings_list = []
        for b in range(B):
            # Hoist the mask comparison: it was previously evaluated three times.
            valid = prompt_mask[b, 0] > 0.0
            if valid.sum() == 0:
                # No prompts for this item: keep its embeddings untouched.
                image_embeddings_list.append(image_embeddings[b : (b + 1)])
                continue
            # (row, col) prompt coordinates, pixel-centered and normalized to (0, 1).
            sparse_depth_pos = valid.nonzero().float()
            sparse_depth_pos[:, 0] = (sparse_depth_pos[:, 0] + 0.5) / H
            sparse_depth_pos[:, 1] = (sparse_depth_pos[:, 1] + 0.5) / W
            sparse_depth = prompt_depth[b, 0][valid]
            prompt_embeddings = self.depth2feature(sparse_depth[:, None])[None, ...]  # 1, N, C
            # The PE layer expects (x, y) order, hence the [1, 0] column swap.
            prompt_pe = self.pe_layer._pe_encoding(sparse_depth_pos[None, :, [1, 0]])  # 1, N, C
            query_pe = image_pe.reshape(1, -1, image_pe.shape[-1])
            prompt = prompt_embeddings
            query = image_embeddings[b : (b + 1)]
            with torch.autocast("cuda", enabled=True, dtype=acc_dtype):
                for block in self.prompt_blocks:
                    query, prompt = block(query, query_pe, prompt, prompt_pe)
            image_embeddings_list.append(query[..., : image_embeddings.shape[-1]])
        return torch.cat(image_embeddings_list, dim=0)
106
+
107
+
108
class SelfAttnRopePromptModel(nn.Module):
    """Variant of SelfAttnPromptModel that can use 2D rotary position embeddings.

    When ``pe`` is of the form ``"rope<freq>"`` a RotaryPositionEmbedding2D is
    used; otherwise a random Fourier PE is used. Prompt tokens are concatenated
    with the image tokens inside ``SelfAttenPromptBlock`` (without a separator
    token) and the updated image tokens are returned.
    """

    def __init__(
        self,
        transformer_dim: int = 1024,
        num_blocks: int = 1,
        num_heads: int = 4,
        pe: str = "normal",
        image_pe_method: str = "patch",  # "patch" or "image"
        **kwargs,
    ) -> None:
        """
        Arguments:
            transformer_dim (int): channel dimension of the image/prompt tokens
            num_blocks (int): number of self-attention prompt blocks
            num_heads (int): attention heads per block
            pe (str): "rope<freq>" selects 2D rotary embeddings with the given
                frequency; otherwise a random Fourier PE is built ("normal"
                adds it before the first block, "apply" forwards it into the
                attention layer)
            image_pe_method (str): grid-PE construction mode forwarded to
                PositionEmbeddingRandom (unused for rope)
        """
        super().__init__()
        self.pe = pe
        pe_dim = transformer_dim // 2
        if self.pe == "apply":
            # PE is consumed per attention head in this mode, so shrink it.
            pe_dim = pe_dim // num_heads
        if self.pe.startswith("rope"):
            self.pe_layer = RotaryPositionEmbedding2D(
                frequency=float(self.pe.split("rope")[1]), feat_dim=pe_dim // num_heads
            )
        else:
            self.pe_layer = PositionEmbeddingRandom(pe_dim, image_pe_method=image_pe_method)
        self.prompt_blocks = nn.ModuleList(
            [
                SelfAttenPromptBlock(
                    dim=transformer_dim, num_heads=num_heads, first_block=(i == 0), pe=pe, use_sep=False
                )
                for i in range(num_blocks)
            ]
        )
        # Small MLP lifting a scalar depth value to a transformer_dim token.
        self.depth2feature = nn.Sequential(
            nn.Linear(1, transformer_dim // 2),
            nn.GELU(),
            nn.Linear(transformer_dim // 2, transformer_dim),
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        prompt_depth: torch.Tensor,
        prompt_mask: torch.Tensor,
        patch_h: int,
        patch_w: int,
    ) -> torch.Tensor:
        """Inject sparse depth prompts into the image embeddings.

        Arguments:
            image_embeddings: (B, N, C) image tokens from the encoder
            prompt_depth: (B, 1, H, W) sparse metric depth map
            prompt_mask: (B, 1, H, W) mask; entries > 0 mark valid prompts
            patch_h, patch_w: token-grid size used to build the image PE

        Returns:
            (B, N, C) image embeddings conditioned on the depth prompts.
            (The original annotation declared a Tuple, but a single tensor
            has always been returned.)
        """
        B, _, H, W = prompt_depth.shape
        # NOTE(review): PositionEmbeddingRandom.forward does not accept a
        # `device` kwarg, so the non-rope branch would fail here — confirm this
        # model is only ever configured with pe="rope...".
        image_pe = self.pe_layer((patch_h, patch_w), device=prompt_depth.device).permute(1, 2, 0)  # CxHxW -> HxWxC
        image_embeddings_list = []
        for b in range(B):
            # Hoist the mask comparison: it was previously evaluated three times.
            valid = prompt_mask[b, 0] > 0.0
            if valid.sum() == 0:
                # No prompts for this item: keep its embeddings untouched.
                image_embeddings_list.append(image_embeddings[b : (b + 1)])
                continue
            # Integer (row, col) prompt coordinates, doubled — presumably to map
            # from the half-resolution mask grid onto the PE grid; TODO confirm.
            sparse_depth_pos = valid.nonzero().int()
            sparse_depth_pos[:, 0] = sparse_depth_pos[:, 0] * 2
            sparse_depth_pos[:, 1] = sparse_depth_pos[:, 1] * 2
            sparse_depth = prompt_depth[b, 0][valid]
            prompt_embeddings = self.depth2feature(sparse_depth[:, None])[None, ...]  # 1, N, C
            prompt_pe = self.pe_layer._pe_encoding(sparse_depth_pos[None])  # 1, N, C
            query_pe = image_pe.reshape(1, -1, image_pe.shape[-1])
            prompt = prompt_embeddings
            query = image_embeddings[b : (b + 1)]
            with torch.autocast("cuda", enabled=True, dtype=acc_dtype):
                for block in self.prompt_blocks:
                    query, prompt = block(query, query_pe, prompt, prompt_pe)
            image_embeddings_list.append(query[..., : image_embeddings.shape[-1]])
        return torch.cat(image_embeddings_list, dim=0)
206
+
207
+
208
class SelfAttenPromptBlock(nn.Module):
    """
    Self-attention block for prompt-based processing that handles both query and
    context features: the two sequences are concatenated (optionally around a
    learned separator token), run through joint self-attention + MLP, and split
    back apart. Supports different positional encoding strategies.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        init_values: float = 0.0,
        first_block: bool = False,
        pe: str = "normal",
        use_sep: bool = True,
        **kwargs,
    ) -> None:
        """
        Arguments:
            dim (int): token channel dimension
            num_heads (int): attention heads
            init_values (float): LayerScale init value; 0 disables LayerScale
            first_block (bool): for pe="normal", PE is added to the tokens only
                in the first block
            pe (str): "normal" adds PE to tokens; any other value forwards the
                PE tensor into the attention layer
            use_sep (bool): insert a learned separator token between the query
                and context sequences
        """
        super().__init__()
        self.first_block = first_block
        self.pe = pe

        # Attention components
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MemEffAttention(dim, num_heads=num_heads, pe=pe)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        # MLP components
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(dim, hidden_features=dim * 4)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        # Separator token for concatenating query and context
        pe_dim = dim
        self.use_sep = use_sep
        if use_sep:
            self.sep = nn.Parameter(torch.randn(1, 1, pe_dim))
        else:
            self.sep = None

        # Special separator for positional encoding if needed
        if self.use_sep:
            if self.pe != "normal":
                self.sep_pe = nn.Parameter(torch.randn(1, 1, pe_dim))
            else:
                self.sep_pe = None

    def forward(self, x: Tensor, x_pe: Tensor, context: Tensor, context_pe: Tensor) -> Tuple[Tensor, Tensor]:
        """Jointly attend over query (``x``) and context tokens.

        Returns the updated (query, context) pair, each restored to its
        original sequence length.
        """
        # Apply positional encoding if this is the first block and using normal PE
        if self.pe == "normal" and self.first_block:
            x = x + x_pe
            context = context + context_pe

        # Record original sequence lengths
        x_len, context_len = x.shape[1], context.shape[1]

        # Concatenate query, separator token, and context. The stored separator
        # has batch dim 1, so expand it to the actual batch size (previously
        # torch.cat would fail for batches larger than 1).
        if self.use_sep:
            x = torch.cat([x, self.sep.expand(x.shape[0], -1, -1), context], dim=1)
        else:
            x = torch.cat([x, context], dim=1)

        # Handle positional encoding concatenation if needed
        if self.pe != "normal":
            if self.use_sep:
                x_pe = torch.cat([x_pe, self.sep_pe.expand(x_pe.shape[0], -1, -1), context_pe], dim=1)
            else:
                x_pe = torch.cat([x_pe, context_pe], dim=1)
            x = x + self.ls1(self.attn(self.norm1(x), x_pe))
        else:
            # Apply standard attention
            x = x + self.ls1(self.attn(self.norm1(x)))

        # Apply MLP
        x = x + self.ls2(self.mlp(self.norm2(x)))

        # Split back into query and context, skipping the separator token if any
        query = x[:, :x_len, :]
        if self.use_sep:
            context = x[:, x_len + 1 : x_len + 1 + context_len, :]
        else:
            context = x[:, x_len : x_len + context_len, :]

        return query, context
InfiniDepth/model/block/prompt_models/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Prompt model utility modules."""
InfiniDepth/model/block/prompt_models/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (189 Bytes). View file
 
InfiniDepth/model/block/prompt_models/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (252 Bytes). View file
 
InfiniDepth/model/block/prompt_models/utils/__pycache__/pe_utils.cpython-310.pyc ADDED
Binary file (2.9 kB). View file
 
InfiniDepth/model/block/prompt_models/utils/__pycache__/pe_utils.cpython-311.pyc ADDED
Binary file (5.25 kB). View file
 
InfiniDepth/model/block/prompt_models/utils/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (7.31 kB). View file
 
InfiniDepth/model/block/prompt_models/utils/__pycache__/transformer.cpython-311.pyc ADDED
Binary file (12.1 kB). View file
 
InfiniDepth/model/block/prompt_models/utils/pe_utils.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Tuple
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies (random Fourier
    features): coordinates in [0, 1]^2 are projected through a fixed Gaussian
    matrix and mapped to (sin, cos) features.
    """

    # Encoder patch size, used to build the high-resolution "image" PE grid.
    patch_size = 14

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None, image_pe_method: str = "patch") -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        # Fixed (non-trainable) random projection; a buffer so it follows the
        # module's device and is saved in the state dict.
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

        self.image_pe_method = image_pe_method
        if self.image_pe_method == "image":
            # Downsamples a pixel-resolution PE map back to the token grid.
            self.patch_embed = nn.Sequential(
                nn.Conv2d(num_pos_feats * 2, num_pos_feats // 2, kernel_size=2, stride=2),
                nn.ReLU(),
                nn.Conv2d(
                    num_pos_feats // 2,
                    num_pos_feats * 2,
                    kernel_size=self.patch_size // 2,
                    stride=self.patch_size // 2,
                ),
            )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward_encoding(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        # Pixel-center coordinates, normalized to (0, 1).
        y_embed = (grid.cumsum(dim=0) - 0.5) / h
        x_embed = (grid.cumsum(dim=1) - 0.5) / w
        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))  # HxWx2 -> HxWxC
        return pe.permute(2, 0, 1)  # C x H x W

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Return a C x H x W positional-encoding grid for the token grid `size`."""
        if self.image_pe_method == "patch":
            return self.forward_encoding(size)
        if self.image_pe_method == "image":
            pe_encoding = self.forward_encoding(size)
            pe_encoding_high = self.forward_encoding((size[0] * self.patch_size, size[1] * self.patch_size))
            return pe_encoding + self.patch_embed(pe_encoding_high[None])[0]
        # Previously an unknown method silently fell through and returned None.
        raise ValueError(f"Unknown image_pe_method: {self.image_pe_method!r}")

    def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
InfiniDepth/model/block/prompt_models/utils/transformer.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Tuple, Type
9
+ import torch
10
+ from einops import rearrange
11
+ from ...common import MLPBlock
12
+ from .....utils.logger import Log
13
+ from torch import Tensor, nn
14
+
15
+ try:
16
+ from xformers.ops import memory_efficient_attention
17
+
18
+ XFORMERS_AVAILABLE = True
19
+ except ImportError:
20
+ Log.warning("xFormers not available")
21
+ XFORMERS_AVAILABLE = False
22
+
23
+
24
class TwoWayTransformer(nn.Module):
    """Two-way transformer decoder (adapted from SAM): prompt/point tokens and
    image tokens cross-attend to each other in both directions at every layer."""

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
            depth (int): number of layers in the transformer
            embedding_dim (int): the channel dimension for the input embeddings
            num_heads (int): the number of heads for multihead attention. Must
                divide embedding_dim
            mlp_dim (int): the channel dimension internal to the MLP block
            activation (nn.Module): the activation to use in the MLP block
            attention_downsample_rate (int): channel downscaling factor applied
                inside the cross-attention layers
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),  # first layer self-attends without PE
                )
            )

        self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self, image_embedding: Tensor, image_pe: Tensor, point_embedding: Tensor, point_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            image_embedding (torch.Tensor): image tokens to attend to; consumed
                here as a B x N_image x embedding_dim token sequence.
            image_pe (torch.Tensor): the positional encoding to add to the image
                tokens. Must have the same shape as image_embedding.
            point_embedding (torch.Tensor): the embedding to add to the query points.
                Must have shape B x N_points x embedding_dim for any N_points.
            point_pe (torch.Tensor): positional encoding for the point tokens.

        Returns:
            torch.Tensor: the processed point_embedding
            torch.Tensor: the processed image_embedding
        """
        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_pe,
                key_pe=image_pe,
            )
            # queries become keys-aware; keys become queries-aware

        # Apply the final attention layer from the points to the image.
        # NOTE(review): the queries are offset by the original point_embedding
        # here while the keys use image_pe (not point_pe) — confirm this
        # asymmetry is intentional; upstream SAM uses the point embedding as
        # the PE itself.
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys
106
+
107
+
108
class TwoWayAttentionBlock(nn.Module):
    """One two-way decoder layer: prompt self-attention, prompt-to-image cross
    attention, an MLP on the prompts, then image-to-prompt cross attention."""

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
            embedding_dim (int): the channel dimension of the embeddings
            num_heads (int): the number of heads in the attention layers
            mlp_dim (int): the hidden dimension of the mlp block
            activation (nn.Module): the activation of the mlp block
            attention_downsample_rate (int): channel downscaling factor for the
                cross-attention layers
            skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        # NOTE(review): self-attention and token-to-image use the
        # xFormers-backed MemEffAttention while image-to-token uses the plain
        # Attention — confirm the mix is intentional.
        self.self_attn = MemEffAttention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = MemEffAttention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
        """Run the four sub-layers; returns the updated (queries, keys)."""
        # Self attention block. On the first layer the PE (and the residual)
        # are skipped entirely.
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
179
+
180
+
181
class Attention(nn.Module):
    """
    Multi-head attention whose q/k/v projections may shrink the channel
    dimension by ``downsample_rate`` before attention is computed.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        # B x N x C -> B x N_heads x N_tokens x C_per_head
        batch, tokens, channels = x.shape
        return x.reshape(batch, tokens, num_heads, channels // num_heads).transpose(1, 2)

    def _recombine_heads(self, x: Tensor) -> Tensor:
        # B x N_heads x N_tokens x C_per_head -> B x N_tokens x C
        batch, heads, tokens, per_head = x.shape
        return x.transpose(1, 2).reshape(batch, tokens, heads * per_head)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Scaled dot-product attention over the projected q/k/v tensors."""
        # Project, then split into heads.
        q = self._separate_heads(self.q_proj(q), self.num_heads)
        k = self._separate_heads(self.k_proj(k), self.num_heads)
        v = self._separate_heads(self.v_proj(v), self.num_heads)

        # Attention weights: B x N_heads x N_tokens x N_tokens
        per_head = q.shape[-1]
        scores = q @ k.permute(0, 1, 3, 2)
        scores = scores / math.sqrt(per_head)
        weights = torch.softmax(scores, dim=-1)

        # Weighted values, recombined and projected back out.
        return self.out_proj(self._recombine_heads(weights @ v))
237
+
238
+
239
class MemEffAttention(Attention):
    """Drop-in replacement for ``Attention`` that computes the attention
    product with xFormers' memory_efficient_attention kernel."""

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Project inputs through the inherited Linear layers.
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)
        # xFormers expects (B, N_tokens, N_heads, C_per_head).
        q = rearrange(q, "b n (m c) -> b n m c", m=self.num_heads)
        k = rearrange(k, "b n (m c) -> b n m c", m=self.num_heads)
        v = rearrange(v, "b n (m c) -> b n m c", m=self.num_heads)
        # NOTE(review): XFORMERS_AVAILABLE is not checked here, so this raises
        # NameError when xFormers failed to import — confirm callers guard it.
        x = memory_efficient_attention(q, k, v, attn_bias=None)
        x = rearrange(x, "b n m c -> b n (m c)")
        x = self.out_proj(x)
        return x
InfiniDepth/model/block/rope.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import torch
3
+
4
+
5
def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=16.0):
    """Precompute 2D rotary-embedding complex phases for a height x width grid.

    Both axes span [0, scale]. Returns a complex tensor of shape
    (height * width, dim // 2) with x- and y-axis phases interleaved along the
    last dimension.
    """
    ys = torch.linspace(0, scale, height)
    xs = torch.linspace(0, scale, width)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    flat_y = grid_y.reshape(-1)
    flat_x = grid_x.reshape(-1)
    # dim // 4 inverse frequencies, shared by both axes.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    angles_x = torch.outer(flat_x, inv_freq).float()  # N x dim/4
    angles_y = torch.outer(flat_y, inv_freq).float()  # N x dim/4
    # Unit-magnitude complex phases e^{i*angle}.
    cis_x = torch.polar(torch.ones_like(angles_x), angles_x)
    cis_y = torch.polar(torch.ones_like(angles_y), angles_y)
    # Interleave x/y phases: (N, dim/4, 2) -> (N, dim/2).
    paired = torch.cat([cis_x.unsqueeze(dim=-1), cis_y.unsqueeze(dim=-1)], dim=-1)
    return paired.reshape(height * width, -1)
21
+
22
def precompute_freqs_cis_ex2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=1.0):
    """Precompute extended 2D rotary phases with per-axis extent scaling.

    ``scale`` may be a float (applied to both axes) or an (sx, sy) pair.
    Returns a complex tensor of shape (height * width, dim // 2).
    """
    if isinstance(scale, float):
        scale = (scale, scale)
    # NOTE(review): the x axis spans height*scale[0] while the y axis spans
    # width*scale[1]; confirm this axis/extent pairing is intentional.
    xs = torch.linspace(0, height * scale[0], width)
    ys = torch.linspace(0, width * scale[1], height)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    flat_y = grid_y.reshape(-1)
    flat_x = grid_x.reshape(-1)
    # dim // 4 inverse frequencies, shared by both axes.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    angles_x = torch.outer(flat_x, inv_freq).float()  # N x dim/4
    angles_y = torch.outer(flat_y, inv_freq).float()  # N x dim/4
    cis_x = torch.polar(torch.ones_like(angles_x), angles_x)
    cis_y = torch.polar(torch.ones_like(angles_y), angles_y)
    # Interleave x/y phases: (N, dim/4, 2) -> (N, dim/2).
    paired = torch.cat([cis_x.unsqueeze(dim=-1), cis_y.unsqueeze(dim=-1)], dim=-1)
    return paired.reshape(height * width, -1)
38
+
39
+
40
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key tensors by precomputed complex phases.

    ``xq``/``xk`` have shape (B, N, H, Hc); consecutive channel pairs are
    interpreted as complex numbers and multiplied by ``freqs_cis``, which is
    broadcast over the two leading dimensions.
    """
    phases = freqs_cis[None, None, :, :]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # (…, Hc) real -> (…, Hc/2) complex, rotate, then back to (…, Hc) real.
        as_complex = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
        return torch.view_as_real(as_complex * phases).flatten(3).type_as(t)

    return _rotate(xq), _rotate(xk)
52
+
53
def apply_rotary_emb_crossattention(
    xq: torch.Tensor,
    xk: torch.Tensor,
    yk: torch.Tensor,
    freqs_cis1: torch.Tensor,
    freqs_cis2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Rotate q/k with one phase table and a second key set with another.

    For cross-attention where the two token sets carry different positional
    grids: ``xq`` and ``xk`` are rotated by ``freqs_cis1``, ``yk`` by
    ``freqs_cis2``. Shapes follow :func:`apply_rotary_emb`.
    """

    def _rotate(t: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
        # (…, Hc) real -> (…, Hc/2) complex, rotate, then back to (…, Hc) real.
        as_complex = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
        return torch.view_as_real(as_complex * phases[None, None, :, :]).flatten(3).type_as(t)

    return _rotate(xq, freqs_cis1), _rotate(xk, freqs_cis1), _rotate(yk, freqs_cis2)
InfiniDepth/model/block/torchhub/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Local PyTorch Hub
2
+
3
+ This directory is used to load the DINOv2 encoder locally when no Internet connection is available.
InfiniDepth/model/block/torchhub/dinov3/.docstr.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ paths:
2
+ - dinov3
3
+ exclude: dinov3/tests
4
+ skip_init: True
5
+ skip_private: True
6
+ fail_under: 0
InfiniDepth/model/block/torchhub/dinov3/.github/workflows/lint.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Lint
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ pull_request:
8
+ branches:
9
+ - main
10
+
11
+ jobs:
12
+ run-linters:
13
+ name: Run linters
14
+ runs-on: ubuntu-latest
15
+
16
+ steps:
17
+ - name: Checkout repository
18
+ uses: actions/checkout@v4
19
+ - name: Set up Python
20
+ uses: actions/setup-python@v5
21
+ with:
22
+ python-version: 3.11
23
+ cache: 'pip'
24
+ cache-dependency-path: '**/requirements*.txt'
25
+ - name: Install Python (development) dependencies
26
+ run: |
27
+ pip install -r requirements-dev.txt
28
+ - name: Run ruff (linter)
29
+ run: |
30
+ ruff check dinov3
31
+ - name: Run ruff (formatter)
32
+ if: always()
33
+ run: |
34
+ ruff format --diff dinov3
35
+ - name: Report docstring coverage
36
+ if: always()
37
+ run: |
38
+ docstr-coverage dinov3
39
+ - name: Run mypy
40
+ if: always()
41
+ run: |
42
+ mypy --txt-report .
43
+ [ -f index.txt ] && cat index.txt
44
+ - name: Run pylint
45
+ if: always()
46
+ run: |
47
+ pylint --exit-zero dinov3
InfiniDepth/model/block/torchhub/dinov3/.gitignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ build/
2
+ dist/
3
+ *.egg-info/
4
+ **/__pycache__/
5
+
6
+ **/.ipynb_checkpoints
7
+ **/.ipynb_checkpoints/**
8
+
9
+ **/notebooks
10
+
11
+ # Ignore shell scripts
12
+ *.sh
13
+
14
+ # Ignore swap files
15
+ *.swp
16
+
17
+ # Ignore vscode directory
18
+ .vscode/
InfiniDepth/model/block/torchhub/dinov3/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <opensource-conduct@meta.com>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
InfiniDepth/model/block/torchhub/dinov3/CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to DINOv3
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Meta's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26
+ disclosure of security bugs. In those cases, please go through the process
27
+ outlined on that page and do not file a public issue.
28
+
29
+ ## License
30
+ By contributing to DINOv3, you agree that your contributions will be licensed
31
+ under the LICENSE.md file in the root directory of this source tree.
InfiniDepth/model/block/torchhub/dinov3/LICENSE.md ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DINOv3 License
2
+
3
+ *Last Updated: August 19, 2025*
4
+
5
+ **“Agreement”** means the terms and conditions for use, reproduction, distribution and modification of the DINO Materials set forth herein.
6
+
7
+ **“DINO Materials”** means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.
8
+
9
+ **“Documentation”** means the specifications, manuals and documentation accompanying
10
+ DINO Materials distributed by Meta.
11
+
12
+ **“Licensee”** or **“you”** means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
13
+
14
+ **“Meta”** or **“we”** means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
15
+
16
+ **“Sanctions”** means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.
17
+
18
+ **“Trade Controls”** means any of the following: Sanctions and applicable export and import controls.
19
+
20
+ By clicking “I Accept” below or by using or distributing any portion or element of the DINO Materials, you agree to be bound by this Agreement.
21
+
22
+ ## 1. License Rights and Redistribution.
23
+
24
+ a. <ins>Grant of Rights</ins>. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the DINO Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the DINO Materials.
25
+
26
+ b. <ins>Redistribution and Use</ins>.
27
+
28
+ i. Distribution of DINO Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the DINO Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such DINO Materials.
29
+
30
+ ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with DINO Materials, you must acknowledge the use of DINO Materials in your publication.
31
+
32
+ iii. Your use of the DINO Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.
33
+
34
+ iv. Your use of the DINO Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the DINO Materials.
35
+
36
+ v. You are not the target of Trade Controls and your use of DINO Materials must comply with Trade Controls. You agree not to use, or permit others to use, DINO Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.
37
+
38
+ ## 2. User Support.
39
+
40
+ Your use of the DINO Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the DINO Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
41
+
42
+ ## 3. Disclaimer of Warranty.
43
+
44
+ UNLESS REQUIRED BY APPLICABLE LAW, THE DINO MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE DINO MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE DINO MATERIALS AND ANY OUTPUT AND RESULTS.
45
+
46
+ ## 4. Limitation of Liability.
47
+
48
+ IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
49
+
50
+ ## 5. Intellectual Property.
51
+
52
+ a. Subject to Meta’s ownership of DINO Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the DINO Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
53
+
54
+ b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the DINO Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the DINO Materials.
55
+
56
+ ## 6. Term and Termination.
57
+
58
+ The term of this Agreement will commence upon your acceptance of this Agreement or access to the DINO Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the DINO Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
59
+
60
+ ## 7. Governing Law and Jurisdiction.
61
+
62
+ This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
63
+
64
+ ## 8. Modifications and Amendments.
65
+
66
+ Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the DINO Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
InfiniDepth/model/block/torchhub/dinov3/MODEL_CARD.md ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Card for DINOv3
2
+
3
+ DINOv3 is a family of versatile vision foundation models that outperforms the specialized state of the art across a broad range of settings, without fine-tuning. DINOv3 produces high-quality dense features that achieve outstanding performance on various vision tasks, significantly surpassing previous self- and weakly-supervised foundation models.
4
+
5
+ ## Model Details
6
+
7
+ These are Vision Transformer and ConvNeXt models trained following the method described in the DINOv3 paper. 12 models are provided:
8
+
9
+ - 10 models pretrained on web data (LVD-1689M dataset)
10
+ - 1 ViT-7B trained from scratch,
11
+ - 5 ViT-S/S+/B/L/H+ models distilled from the ViT-7B,
12
+ - 4 ConvNeXt-{T/S/B/L} models distilled from the ViT-7B,
13
+ - 2 models pretrained on satellite data (SAT-493M dataset)
14
+ - 1 ViT-7B trained from scratch
15
+ - 1 ViT-L distilled from the ViT-7B
16
+
17
+
18
+ Each Transformer-based model takes an image as input and returns a class token, patch tokens (and register tokens). These models follow a ViT architecture, with a patch size of 16. For a 224x224 image, this results in 1 class token + 4 register tokens + 196 patch tokens = 201 tokens (for DINOv2 with registers this resulted in 1 + 4 + 256 = 261 tokens).
19
+
20
+ The models can accept larger images provided the image shapes are multiples of the patch size (16). If this condition is not verified, the model will crop to the closest smaller multiple of the patch size.
21
+
22
+ ### Model Description
23
+
24
+ - **Developed by:** Meta AI
25
+ - **Model type:** Vision Transformer, ConvNeXt
26
+ - **License:** [DINOv3 License](https://ai.meta.com/resources/models-and-libraries/dinov3-license/)
27
+
28
+ ### Model Sources
29
+
30
+ - **Repository:** [https://github.com/facebookresearch/dinov3](https://github.com/facebookresearch/dinov3)
31
+ - **Paper:** [https://arxiv.org/abs/2508.10104](https://arxiv.org/abs/2508.10104)
32
+
33
+ ## Uses
34
+
35
+ The models are vision backbones providing multi-purpose features for downstream tasks.
36
+
37
+ ### Direct Use
38
+
39
+ The models can be used without fine-tuning, with downstream classifiers as simple as linear layers, to obtain competitive results:
40
+
41
+ - on image classification, using k-NN classifiers on the class token
42
+ - on image classification, with logistic regression classifiers applied on the class token
43
+ - on image classification, with a linear layer applied on the class token and the average of the patch tokens
44
+ - on image retrieval using nearest neighbors
45
+ - on geometric and semantic 3D keypoint correspondences
46
+ - on depth estimation, semantic segmentation, using linear layers
47
+ - on unsupervised object discovery
48
+ - on video segmentation tracking
49
+ - on video classification, using a small 4-layer attentive probe
50
+
51
+ ### Downstream Use
52
+
53
+ While fine-tuning the models can yield some gains, it is recommended to keep this option as a last resort: the frozen features are expected to provide good performance out-of-the-box.
54
+
55
+ ## Bias, Risks, and Limitations
56
+
57
+ Compared to DINOv2 and SEERv2, DINOv3 delivers somewhat consistent performance across income categories on geographical fairness and diversity, although with a notable performance drop in the low-income bucket compared to the highest-income bucket.
58
+
59
+ DINOv3 also achieves relatively good scores across different regions, improving over its predecessor DINOv2. However, a relative difference is still observed between Europe and Africa.
60
+
61
+ ### Recommendations
62
+
63
+ Fine-tuning is expected to increase the biases in the features produced by the model as they will be tuned to the fine-tuning labels.
64
+
65
+ ## How to Get Started with the Model
66
+
67
+ Use the code below to get started with the model.
68
+
69
+ ```python
70
+ import torch
71
+
72
+ model = torch.hub.load(
73
+ repo_or_dir='facebookresearch/dinov3',
74
+ model='<MODEL_NAME>',
75
+ weights='<PATH/OR/URL/TO/CHECKPOINT>',
76
+ )
77
+
78
+ # where MODEL_NAME can be one of:
79
+ # - dinov3_vits16
80
+ # - dinov3_vits16plus
81
+ # - dinov3_vitb16
82
+ # - dinov3_vitl16
83
+ # - dinov3_vith16plus
84
+ # - dinov3_vit7b16
85
+ # - dinov3_convnext_tiny
86
+ # - dinov3_convnext_small
87
+ # - dinov3_convnext_base
88
+ # - dinov3_convnext_large
89
+
90
+ # For instance
91
+ dinov3_vits16 = torch.hub.load(
92
+ repo_or_dir='facebookresearch/dinov3',
93
+ model='dinov3_vits16',
94
+ weights='<PATH/OR/URL/TO/DINOV3/VITS16/LVD1689M/CHECKPOINT>',
95
+ )
96
+ ```
97
+
98
+ ## Training Details
99
+
100
+ ### Training Data
101
+
102
+ - Web dataset (LVD-1689M): a curated dataset of 1,689 million images extracted from a large data
103
+ pool of 17 billion web images collected from public posts on Instagram
104
+
105
+ - Satellite dataset (SAT-493M): a dataset of 493 million 512x512 images sampled randomly from Maxar RGB ortho-rectified imagery at 0.6 meter resolution
106
+
107
+ ### Training Procedure
108
+
109
+ **Training objective:**
110
+
111
+ - DINO self-distillation loss with multi-crop
112
+ - iBOT masked-image modeling loss
113
+ - KoLeo regularization on [CLS] tokens
114
+ - Gram anchoring
115
+
116
+ - **Training regime:** PyTorch FSDP2 (with bf16 and fp8 matrix multiplications)
117
+
118
+ **Distillation:**
119
+
120
+ - Distillation follows the standard DINOv3 pretraining procedure, except the teacher is a frozen pretrained ViT-7B.
121
+
122
+ ## Evaluation
123
+
124
+ **Results**
125
+
126
+ The reader is referred to the associated paper for details on the evaluation protocols
127
+
128
+ *Results for ViT backbones pretrained (or distilled) on web (LVD-1689M)*
129
+
130
+ <table>
131
+ <tr>
132
+ <th></th>
133
+ <!-- <th></th> -->
134
+ <th colspan="4">Global Tasks</th>
135
+ <th colspan="5">Dense Tasks</th>
136
+ </tr>
137
+ <tr>
138
+ <th>Model</th>
139
+ <!-- <th>Dataset</th> -->
140
+ <th>IN-ReaL</th>
141
+ <th>IN-R</th>
142
+ <th>Obj.Net</th>
143
+ <th>Ox.-H</th>
144
+ <th>ADE20k</th>
145
+ <th>NYU↓</th>
146
+ <th>DAVIS</th>
147
+ <th>NAVI</th>
148
+ <th>SPair</th>
149
+ </tr>
150
+ <tr>
151
+ <td>DINOv3 ViT-S/16</td>
152
+ <!-- <td>LVD-1689M</td> -->
153
+ <td align="right">87.0</td>
154
+ <td align="right">60.4</td>
155
+ <td align="right">50.9</td>
156
+ <td align="right">49.5</td>
157
+ <td align="right">47.0</td>
158
+ <td align="right">0.403</td>
159
+ <td align="right">72.7</td>
160
+ <td align="right">56.3</td>
161
+ <td align="right">50.4</td>
162
+ </tr>
163
+ <tr>
164
+ <td>DINOv3 ViT-S+/16</td>
165
+ <!-- <td>LVD-1689M</td> -->
166
+ <td align="right">88.0</td>
167
+ <td align="right">68.8</td>
168
+ <td align="right">54.6</td>
169
+ <td align="right">50.0</td>
170
+ <td align="right">48.8</td>
171
+ <td align="right">0.399</td>
172
+ <td align="right">75.5</td>
173
+ <td align="right">57.1</td>
174
+ <td align="right">55.2</td>
175
+ </tr>
176
+ <tr>
177
+ <td>DINOv3 ViT-B/16</td>
178
+ <!-- <td>LVD-1689M</td> -->
179
+ <td align="right">89.3</td>
180
+ <td align="right">76.7</td>
181
+ <td align="right">64.1</td>
182
+ <td align="right">58.5</td>
183
+ <td align="right">51.8</td>
184
+ <td align="right">0.373</td>
185
+ <td align="right">77.2</td>
186
+ <td align="right">58.8</td>
187
+ <td align="right">57.2</td>
188
+ </tr>
189
+ <tr>
190
+ <td>DINOv3 ViT-L/16</td>
191
+ <!-- <td>LVD-1689M</td> -->
192
+ <td align="right">90.2</td>
193
+ <td align="right">88.1</td>
194
+ <td align="right">74.8</td>
195
+ <td align="right">63.1</td>
196
+ <td align="right">54.9</td>
197
+ <td align="right">0.352</td>
198
+ <td align="right">79.9</td>
199
+ <td align="right">62.3</td>
200
+ <td align="right">61.3</td>
201
+ </tr>
202
+ <tr>
203
+ <td>DINOv3 ViT-H+/16</td>
204
+ <!-- <td>LVD-1689M</td> -->
205
+ <td align="right">90.3</td>
206
+ <td align="right">90.0</td>
207
+ <td align="right">78.6</td>
208
+ <td align="right">64.5</td>
209
+ <td align="right">54.8</td>
210
+ <td align="right">0.352</td>
211
+ <td align="right">79.3</td>
212
+ <td align="right">63.3</td>
213
+ <td align="right">56.3</td>
214
+ </tr>
215
+ <tr>
216
+ <td>DINOv3 ViT-7B/16</td>
217
+ <!-- <td>LVD-1689M</td> -->
218
+ <td align="right">90.4</td>
219
+ <td align="right">91.1</td>
220
+ <td align="right">91.1</td>
221
+ <td align="right">72.8</td>
222
+ <td align="right">55.9</td>
223
+ <td align="right">0.309</td>
224
+ <td align="right">79.7</td>
225
+ <td align="right">64.4</td>
226
+ <td align="right">58.7</td>
227
+ </tr>
228
+ </table>
229
+
230
+ *Results for ConvNeXt backbones distilled on web (LVD-1689M)*
231
+
232
+ <table>
233
+ <tr>
234
+ <th></th>
235
+ <th colspan="6">Global Tasks</th>
236
+ <th colspan="2">Dense Tasks</th>
237
+ </tr>
238
+ <tr>
239
+ <th>Model</th>
240
+ <th colspan="2">IN-ReaL</th>
241
+ <th colspan="2">IN-R</th>
242
+ <th colspan="2">Obj.Net</th>
243
+ <th>ADE20k</th>
244
+ <th>NYU↓</th>
245
+ </tr>
246
+ <tr>
247
+ <td></td>
248
+ <td>@256px</td>
249
+ <td>@512px</td>
250
+ <td>@256px</td>
251
+ <td>@512px</td>
252
+ <td>@256px</td>
253
+ <td>@512px</td>
254
+ <td colspan="2"></td>
255
+ </tr>
256
+ <tr>
257
+ <td>DINOv3 ConvNeXt Tiny</td>
258
+ <td align="right">86.6</td>
259
+ <td align="right">87.7</td>
260
+ <td align="right">73.7</td>
261
+ <td align="right">74.1</td>
262
+ <td align="right">52.6</td>
263
+ <td align="right">58.7</td>
264
+ <td align="right">42.7</td>
265
+ <td align="right">0.448</td>
266
+ </tr>
267
+ <tr>
268
+ <td>DINOv3 ConvNeXt Small</td>
269
+ <td align="right">87.9</td>
270
+ <td align="right">88.7</td>
271
+ <td align="right">73.7</td>
272
+ <td align="right">74.1</td>
273
+ <td align="right">52.6</td>
274
+ <td align="right">58.7</td>
275
+ <td align="right">44.8</td>
276
+ <td align="right">0.432</td>
277
+ </tr>
278
+ <tr>
279
+ <td>DINOv3 ConvNeXt Base</td>
280
+ <td align="right">88.5</td>
281
+ <td align="right">89.2</td>
282
+ <td align="right">77.2</td>
283
+ <td align="right">78.2</td>
284
+ <td align="right">56.2</td>
285
+ <td align="right">61.3</td>
286
+ <td align="right">46.3</td>
287
+ <td align="right">0.420</td>
288
+ </tr>
289
+ <tr>
290
+ <td>DINOv3 ConvNeXt Large</td>
291
+ <td align="right">88.9</td>
292
+ <td align="right">89.4</td>
293
+ <td align="right">81.3</td>
294
+ <td align="right">82.4</td>
295
+ <td align="right">59.3</td>
296
+ <td align="right">65.2</td>
297
+ <td align="right">47.8</td>
298
+ <td align="right">0.403</td>
299
+ </tr>
300
+ </table>
301
+
302
+ *Results for ViT backbones pretrained (or distilled) on satellite (SAT-493M)*
303
+
304
+ <table>
305
+ <tr>
306
+ <th></th>
307
+ <th colspan="7">(GEO-Bench) Classification</th>
308
+ </tr>
309
+ <tr>
310
+ <th>Model</th>
311
+ <th>m-BEnet</th>
312
+ <th>m-brick-kiln</th>
313
+ <th>m-eurosat</th>
314
+ <th>m-forestnet</th>
315
+ <th>m-pv4ger</th>
316
+ <th>m-so2sat</th>
317
+ <th>mean</th>
318
+ </tr>
319
+ <tr>
320
+ <td>DINOv3 ViT-L/16</td>
321
+ <td>73.0</td>
322
+ <td>96.5</td>
323
+ <td>94.1</td>
324
+ <td>60.6</td>
325
+ <td>96.0</td>
326
+ <td>57.4</td>
327
+ <td>79.6</td>
328
+ </tr>
329
+ <tr>
330
+ <td>DINOv3 ViT-7B/16</td>
331
+ <td>74.0</td>
332
+ <td>97.2</td>
333
+ <td>94.8</td>
334
+ <td>62.3</td>
335
+ <td>96.1</td>
336
+ <td>62.1</td>
337
+ <td>81.1</td>
338
+ </tr>
339
+ <tr>
340
+ <th></th>
341
+ <th colspan="7">(GEO-Bench) Segmentation</th>
342
+ </tr>
343
+ <tr>
344
+ <th>Model</th>
345
+ <th>m-cashew</th>
346
+ <th>m-chesapeake</th>
347
+ <th>m-NeonTree</th>
348
+ <th>m-nz-cattle</th>
349
+ <th>m-pv4ger-seg</th>
350
+ <th>m-SA-crop</th>
351
+ <th>mean</th>
352
+ </tr>
353
+ <tr>
354
+ <td>DINOv3 ViT-L/16</td>
355
+ <td>94.2</td>
356
+ <td>75.6</td>
357
+ <td>61.8</td>
358
+ <td>83.7</td>
359
+ <td>95.2</td>
360
+ <td>36.8</td>
361
+ <td>74.5</td>
362
+ </tr>
363
+ <tr>
364
+ <td>DINOv3 ViT-7B/16</td>
365
+ <td>94.1</td>
366
+ <td>76.6</td>
367
+ <td>62.6</td>
368
+ <td>83.4</td>
369
+ <td>95.5</td>
370
+ <td>37.6</td>
371
+ <td>75.0</td>
372
+ </tr>
373
+ </table>
374
+
375
+
376
+ ## Environmental Impact
377
+
378
+ - **Hardware Type:** Nvidia H100
379
+ - **Hours used:** 61,440 hours for ViT-7B model training
380
+ - **Cloud Provider:** Private infrastructure
381
+ - **Compute Region:** USA
382
+ - **Carbon Emitted:** 18t CO2eq
383
+
384
+ ## Technical Specifications
385
+
386
+ ### Model Architecture and Objective
387
+
388
+ Vision Transformer models:
389
+
390
+ - ViT-S (21M parameters): patch size 16, embedding dimension 384, 4 register tokens, 6 heads, MLP FFN, RoPE
391
+ - ViT-S+ (29M parameters): patch size 16, embedding dimension 384, 4 register tokens, 6 heads, SwiGLU FFN, RoPE
392
+ - ViT-B (86M parameters): patch size 16, embedding dimension 768, 4 register tokens, 12 heads, MLP FFN, RoPE
393
+ - ViT-L (300M parameters): patch size 16, embedding dimension 1024, 4 register tokens, 16 heads, MLP FFN, RoPE
394
+ - ViT-H+ (840M parameters): patch size 16, embedding dimension 1280, 4 register tokens, 20 heads, SwiGLU FFN, RoPE
395
+ - ViT-7B (6716M parameters): patch size 16, embedding dimension 4096, 4 register tokens, 32 heads, SwiGLU FFN, RoPE
396
+
397
+ ConvNeXt models:
398
+
399
+ - ConvNeXt Tiny (29M parameters)
400
+ - ConvNeXt Small (50M parameters)
401
+ - ConvNeXt Base (89M parameters)
402
+ - ConvNeXt Large (198M parameters)
403
+
404
+ ### Compute Infrastructure
405
+
406
+ #### Hardware
407
+
408
+ Nvidia H100 GPUs
409
+
410
+ #### Software
411
+
412
+ PyTorch 2.7
413
+
414
+ ## More Information
415
+
416
+ See the [blog post](https://ai.meta.com/blog/dinov3-self-supervised-vision-model/) and the associated [website](https://ai.meta.com/dinov3/).
417
+
418
+ ## Citation
419
+
420
+ **BibTeX**
421
+
422
+ ```
423
+ @misc{simeoni2025dinov3,
424
+ title={{DINOv3}},
425
+ author={Sim{\'e}oni, Oriane and Vo, Huy V. and Seitzer, Maximilian and Baldassarre, Federico and Oquab, Maxime and Jose, Cijo and Khalidov, Vasil and Szafraniec, Marc and Yi, Seungeun and Ramamonjisoa, Micha{\"e}l and Massa, Francisco and Haziza, Daniel and Wehrstedt, Luca and Wang, Jianyuan and Darcet, Timoth{\'e}e and Moutakanni, Th{\'e}o and Sentana, Leonel and Roberts, Claire and Vedaldi, Andrea and Tolan, Jamie and Brandt, John and Couprie, Camille and Mairal, Julien and J{\'e}gou, Herv{\'e} and Labatut, Patrick and Bojanowski, Piotr},
426
+ year={2025},
427
+ eprint={2508.10104},
428
+ archivePrefix={arXiv},
429
+ primaryClass={cs.CV},
430
+ url={https://arxiv.org/abs/2508.10104},
431
+ }
432
+ ```
InfiniDepth/model/block/torchhub/dinov3/README.md ADDED
@@ -0,0 +1,734 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 🆕 [2025-08-14] :fire: DINOv3 backbones are now available in [Hugging Face Hub](https://huggingface.co/collections/facebook/dinov3-68924841bd6b561778e31009) and [supported](https://huggingface.co/docs/transformers/model_doc/dinov3) by the Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) library
2
+
3
+ # DINOv3 🦖🦖🦖
4
+
5
+ **[Meta AI Research, FAIR](https://ai.meta.com/research/)**
6
+
7
+ Oriane Siméoni, Huy V. Vo, Maximilian Seitzer, Federico Baldassarre, Maxime Oquab, <br/>
8
+ Cijo Jose, Vasil Khalidov, Marc Szafraniec, Seungeun Yi, Michaël Ramamonjisoa, <br/>
9
+ Francisco Massa, Daniel Haziza, Luca Wehrstedt, Jianyuan Wang, <br/>
10
+ Timothée Darcet, Théo Moutakanni, Leonel Sentana, Claire Roberts, <br/>
11
+ Andrea Vedaldi, Jamie Tolan, John Brandt, Camille Couprie, <br/>
12
+ Julien Mairal, Hervé Jégou, Patrick Labatut, Piotr Bojanowski
13
+
14
+ [ :scroll: [`Paper`](https://arxiv.org/abs/2508.10104)] [ :newspaper: [`Blog`](https://ai.meta.com/blog/dinov3-self-supervised-vision-model/)] [ :globe_with_meridians: [`Website`](https://ai.meta.com/dinov3/)] [ :book: [`BibTeX`](#citing-dinov3)]
15
+
16
+ Reference PyTorch implementation and models for DINOv3. For details, see the **[DINOv3](https://arxiv.org/abs/2508.10104)** paper.
17
+
18
+ ## Overview
19
+
20
+ <div align="center">
21
+ <img width="1364" height="1024" alt="market" src="https://github.com/user-attachments/assets/1411f491-988e-49cb-95ae-d03fe6e3c268" />
22
+
23
+ <i><b>High-resolution dense features.</b><br/>We visualize the cosine similarity maps obtained with DINOv3 output features<br/> between the patches marked with a red cross and all other patches.</i>
24
+ </div>
25
+
26
+ <br/>
27
+
28
+ An extended family of versatile vision foundation models producing high-quality dense features and achieving outstanding performance on various vision tasks including outperforming the specialized state of the art across a broad range of settings, without fine-tuning
29
+
30
+ ## Pretrained models
31
+
32
+ :information_source: Please follow the link provided below to get access to all the model weights: once accepted, an e-mail will be sent with the complete list of URLs pointing to all the available model weights (both backbones and adapters). These URLs can then be used to either:
33
+ - download the model or adapter weights to a local filesystem and point `torch.hub.load()` to these local weights via the `weights` or `backbone_weights` parameters, or
34
+ - directly invoke `torch.hub.load()` to download and load a backbone or an adapter from its URL via also the `weights` or `backbone_weights` parameters.
35
+
36
+ See the example code snippets below.
37
+
38
+ :warning: Please use `wget` instead of a web browser to download the weights.
39
+
40
+ ViT models pretrained on web dataset (LVD-1689M):
41
+ <table style="margin: auto">
42
+ <thead>
43
+ <tr>
44
+ <th>Model</th>
45
+ <th>Parameters</th>
46
+ <th>Pretraining<br/>Dataset</th>
47
+ <th>Download</th>
48
+ </tr>
49
+ </thead>
50
+ <tbody>
51
+ <tr>
52
+ <td>ViT-S/16 distilled </td>
53
+ <td align="right">21M</td>
54
+ <td align="center">LVD-1689M</td>
55
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
56
+ </tr>
57
+ <tr>
58
+ <td>ViT-S+/16 distilled</td>
59
+ <td align="right">29M</td>
60
+ <td align="center">LVD-1689M</td>
61
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
62
+ </tr>
63
+ <tr>
64
+ <td>ViT-B/16 distilled</td>
65
+ <td align="right">86M</td>
66
+ <td align="center">LVD-1689M</td>
67
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
68
+ </tr>
69
+ <tr>
70
+ <td>ViT-L/16 distilled</td>
71
+ <td align="right">300M</td>
72
+ <td align="center">LVD-1689M</td>
73
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
74
+ </tr>
75
+ <tr>
76
+ <td>ViT-H+/16 distilled</td>
77
+ <td align="right">840M</td>
78
+ <td align="center">LVD-1689M</td>
79
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
80
+ </tr>
81
+ <tr>
82
+ <td>ViT-7B/16</td>
83
+ <td align="right">6,716M</td>
84
+ <td align="center">LVD-1689M</td>
85
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
86
+ </tr>
87
+ </tbody>
88
+ </table>
89
+
90
+ ConvNeXt models pretrained on web dataset (LVD-1689M):
91
+ <table style="margin: auto">
92
+ <thead>
93
+ <tr>
94
+ <th>Model</th>
95
+ <th>Parameters</th>
96
+ <th>Pretraining<br/>Dataset</th>
97
+ <th>Download</th>
98
+ </tr>
99
+ </thead>
100
+ <tbody>
101
+ <tr>
102
+ <td>ConvNeXt Tiny</td>
103
+ <td align="right">29M</td>
104
+ <td align="center">LVD-1689M</td>
105
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
106
+ </tr>
107
+ <tr>
108
+ <td>ConvNeXt Small</td>
109
+ <td align="right">50M</td>
110
+ <td align="center">LVD-1689M</td>
111
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
112
+ </tr>
113
+ <tr>
114
+ <td>ConvNeXt Base</td>
115
+ <td align="right">89M</td>
116
+ <td align="center">LVD-1689M</td>
117
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
118
+ </tr>
119
+ <tr>
120
+ <td>ConvNeXt Large</td>
121
+ <td align="right">198M</td>
122
+ <td align="center">LVD-1689M</td>
123
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
124
+ </tr>
125
+ </tbody>
126
+ </table>
127
+
128
+ ViT models pretrained on satellite dataset (SAT-493M):
129
+ <table style="margin: auto">
130
+ <thead>
131
+ <tr>
132
+ <th>Model</th>
133
+ <th>Parameters</th>
134
+ <th>Pretraining<br/>Dataset</th>
135
+ <th>Download</th>
136
+ </tr>
137
+ </thead>
138
+ <tbody>
139
+ <tr>
140
+ <td>ViT-L/16 distilled</td>
141
+ <td align="right">300M</td>
142
+ <td align="center">SAT-493M</td>
143
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
144
+ </tr>
145
+ <tr>
146
+ <td>ViT-7B/16</td>
147
+ <td align="right">6,716M</td>
148
+ <td align="center">SAT-493M</td>
149
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
150
+ </tr>
151
+ </tbody>
152
+ </table>
153
+
154
+
155
+ ### Pretrained backbones (via PyTorch [Hub](https://docs.pytorch.org/docs/stable/hub.html))
156
+
157
+ Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install PyTorch (the only required dependency for loading the model). Installing PyTorch with CUDA support is strongly recommended.
158
+
159
+ ```python
160
+ import torch
161
+
162
+ REPO_DIR = <PATH/TO/A/LOCAL/DIRECTORY/WHERE/THE/DINOV3/REPO/WAS/CLONED>
163
+
164
+ # DINOv3 ViT models pretrained on web images
165
+ dinov3_vits16 = torch.hub.load(REPO_DIR, 'dinov3_vits16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
166
+ dinov3_vits16plus = torch.hub.load(REPO_DIR, 'dinov3_vits16plus', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
167
+ dinov3_vitb16 = torch.hub.load(REPO_DIR, 'dinov3_vitb16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
168
+ dinov3_vitl16 = torch.hub.load(REPO_DIR, 'dinov3_vitl16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
169
+ dinov3_vith16plus = torch.hub.load(REPO_DIR, 'dinov3_vith16plus', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
170
+ dinov3_vit7b16 = torch.hub.load(REPO_DIR, 'dinov3_vit7b16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
171
+
172
+ # DINOv3 ConvNeXt models pretrained on web images
173
+ dinov3_convnext_tiny = torch.hub.load(REPO_DIR, 'dinov3_convnext_tiny', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
174
+ dinov3_convnext_small = torch.hub.load(REPO_DIR, 'dinov3_convnext_small', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
175
+ dinov3_convnext_base = torch.hub.load(REPO_DIR, 'dinov3_convnext_base', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
176
+ dinov3_convnext_large = torch.hub.load(REPO_DIR, 'dinov3_convnext_large', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
177
+
178
+ # DINOv3 ViT models pretrained on satellite imagery
179
+ dinov3_vitl16 = torch.hub.load(REPO_DIR, 'dinov3_vitl16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
180
+ dinov3_vit7b16 = torch.hub.load(REPO_DIR, 'dinov3_vit7b16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
181
+ ```
182
+
183
+ ### Pretrained backbones (via Hugging Face [Transformers](https://huggingface.co/docs/transformers/))
184
+
185
+ All the backbones are available in the [DINOv3](https://huggingface.co/collections/facebook/dinov3-68924841bd6b561778e31009) collection on Hugging Face Hub and supported via the Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) library. Please refer to the corresponding documentation for usage, but below is a short example that demonstrates how to obtain an image embedding with either [Pipeline] or the [AutoModel] class.
186
+
187
+ ```python
188
+ from transformers import pipeline
189
+ from transformers.image_utils import load_image
190
+
191
+ url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
192
+ image = load_image(url)
193
+
194
+ feature_extractor = pipeline(
195
+ model="facebook/dinov3-convnext-tiny-pretrain-lvd1689m",
196
+ task="image-feature-extraction",
197
+ )
198
+ features = feature_extractor(image)
199
+ ```
200
+
201
+ ```python
202
+ import torch
203
+ from transformers import AutoImageProcessor, AutoModel
204
+ from transformers.image_utils import load_image
205
+
206
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
207
+ image = load_image(url)
208
+
209
+ pretrained_model_name = "facebook/dinov3-convnext-tiny-pretrain-lvd1689m"
210
+ processor = AutoImageProcessor.from_pretrained(pretrained_model_name)
211
+ model = AutoModel.from_pretrained(
212
+ pretrained_model_name,
213
+ device_map="auto",
214
+ )
215
+
216
+ inputs = processor(images=image, return_tensors="pt").to(model.device)
217
+ with torch.inference_mode():
218
+ outputs = model(**inputs)
219
+
220
+ pooled_output = outputs.pooler_output
221
+ print("Pooled output shape:", pooled_output.shape)
222
+ ```
223
+
224
+ where `model` and `pretrained_model_name` above can be one of:
225
+ - `facebook/dinov3-vits16-pretrain-lvd1689m`
226
+ - `facebook/dinov3-vits16plus-pretrain-lvd1689m`
227
+ - `facebook/dinov3-vitb16-pretrain-lvd1689m`
228
+ - `facebook/dinov3-vitl16-pretrain-lvd1689m`
229
+ - `facebook/dinov3-vith16plus-pretrain-lvd1689m`
230
+ - `facebook/dinov3-vit7b16-pretrain-lvd1689m`
231
+ - `facebook/dinov3-convnext-base-pretrain-lvd1689m`
232
+ - `facebook/dinov3-convnext-large-pretrain-lvd1689m`
233
+ - `facebook/dinov3-convnext-small-pretrain-lvd1689m`
234
+ - `facebook/dinov3-convnext-tiny-pretrain-lvd1689m`
235
+ - `facebook/dinov3-vitl16-pretrain-sat493m`
236
+ - `facebook/dinov3-vit7b16-pretrain-sat493m`
237
+
238
+ ### Image transforms
239
+
240
+ For models using the LVD-1689M weights (pretrained on web images), please use the following transform (standard ImageNet evaluation transform):
241
+
242
+ ```python
243
+ from torchvision import transforms
244
+
245
+ def make_transform(resize_size: int = 224):
246
+ to_tensor = transforms.ToTensor()
247
+ resize = transforms.Resize((resize_size, resize_size), antialias=True)
248
+ normalize = transforms.Normalize(
249
+ mean=(0.485, 0.456, 0.406),
250
+ std=(0.229, 0.224, 0.225),
251
+ )
252
+ return transforms.Compose([to_tensor, resize, normalize])
253
+ ```
254
+
255
+
256
+ For models using the SAT-493M weights (pretrained on satellite imagery), please use the following transform:
257
+
258
+
259
+ ```python
260
+ from torchvision import transforms
261
+
262
+ def make_transform(resize_size: int = 224):
263
+ to_tensor = transforms.ToTensor()
264
+ resize = transforms.Resize((resize_size, resize_size), antialias=True)
265
+ normalize = transforms.Normalize(
266
+ mean=(0.430, 0.411, 0.296),
267
+ std=(0.213, 0.156, 0.143),
268
+ )
269
+ return transforms.Compose([to_tensor, resize, normalize])
270
+ ```
271
+
272
+ ### Pretrained heads - Image classification
273
+
274
+ <table style="margin: auto">
275
+ <thead>
276
+ <tr>
277
+ <th>Backbone</th>
278
+ <th>Pretraining<br/>Dataset</th>
279
+ <th>Head<br/>Dataset</th>
280
+ <th>Download</th>
281
+ </tr>
282
+ </thead>
283
+ <tbody>
284
+ <tr>
285
+ <td>ViT-7B/16</td>
286
+ <td align="center">LVD-1689M</td>
287
+ <td align="center">ImageNet</td>
288
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
289
+ </tr>
290
+ </tbody>
291
+ </table>
292
+
293
+
294
+ The (full) classifier models can be loaded via PyTorch Hub:
295
+
296
+ ```python
297
+ import torch
298
+
299
+ # DINOv3
300
+ dinov3_vit7b16_lc = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_lc', source="local", weights=<CLASSIFIER/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
301
+
302
+ ```
303
+
304
+ ### Pretrained heads - Depther trained on SYNTHMIX dataset
305
+
306
+ <table style="margin: auto">
307
+ <thead>
308
+ <tr>
309
+ <th>Backbone</th>
310
+ <th>Pretraining<br/>Dataset</th>
311
+ <th>Head<br/>Dataset</th>
312
+ <th>Download</th>
313
+ </tr>
314
+ </thead>
315
+ <tbody>
316
+ <tr>
317
+ <td>ViT-7B/16</td>
318
+ <td align="center">LVD-1689M</td>
319
+ <td align="center">SYNTHMIX</td>
320
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
321
+ </tr>
322
+ </tbody>
323
+ </table>
324
+
325
+
326
+ ```python
327
+ depther = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_dd', source="local", weights=<DEPTHER/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
328
+ ```
329
+
330
+ Full example code of depther on an image
331
+
332
+ ```python
333
+ from PIL import Image
334
+ import torch
335
+ from torchvision import transforms
336
+ import matplotlib.pyplot as plt
337
+ from matplotlib import colormaps
338
+
339
+ def get_img():
340
+ import requests
341
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
342
+ image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
343
+ return image
344
+
345
+ def make_transform(resize_size: int | list[int] = 768):
346
+ to_tensor = transforms.ToTensor()
347
+ resize = transforms.Resize((resize_size, resize_size), antialias=True)
348
+ normalize = transforms.Normalize(
349
+ mean=(0.485, 0.456, 0.406),
350
+ std=(0.229, 0.224, 0.225),
351
+ )
352
+ return transforms.Compose([to_tensor, resize, normalize])
353
+
354
+ depther = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_dd', source="local", weights=<DEPTHER/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
355
+
356
+ img_size = 1024
357
+ img = get_img()
358
+ transform = make_transform(img_size)
359
+ with torch.inference_mode():
360
+ with torch.autocast('cuda', dtype=torch.bfloat16):
361
+ batch_img = transform(img)[None]
362
+ batch_img = batch_img
363
+ depths = depther(batch_img)
364
+
365
+ plt.figure(figsize=(12, 6))
366
+ plt.subplot(121)
367
+ plt.imshow(img)
368
+ plt.axis("off")
369
+ plt.subplot(122)
370
+ plt.imshow(depths[0,0].cpu(), cmap=colormaps["Spectral"])
371
+ plt.axis("off")
372
+
373
+ ```
374
+
375
+ ### Pretrained heads - Detector trained on COCO2017 dataset
376
+
377
+ <table style="margin: auto">
378
+ <thead>
379
+ <tr>
380
+ <th>Backbone</th>
381
+ <th>Pretraining<br/>Dataset</th>
382
+ <th>Head<br/>Dataset</th>
383
+ <th>Download</th>
384
+ </tr>
385
+ </thead>
386
+ <tbody>
387
+ <tr>
388
+ <td>ViT-7B/16</td>
389
+ <td align="center">LVD-1689M</td>
390
+ <td align="center">COCO2017</td>
391
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
392
+ </tr>
393
+ </tbody>
394
+ </table>
395
+
396
+
397
+ ```python
398
+ detector = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_de', source="local", weights=<DETECTOR/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
399
+ ```
400
+
401
+ ### Pretrained heads - Segmentor trained on ADE20K dataset
402
+
403
+ <table style="margin: auto">
404
+ <thead>
405
+ <tr>
406
+ <th>Backbone</th>
407
+ <th>Pretraining<br/>Dataset</th>
408
+ <th>Head<br/>Dataset</th>
409
+ <th>Download</th>
410
+ </tr>
411
+ </thead>
412
+ <tbody>
413
+ <tr>
414
+ <td>ViT-7B/16</td>
415
+ <td align="center">LVD-1689M</td>
416
+ <td align="center">ADE20K</td>
417
+ <td align="center"><a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a></td>
418
+ </tr>
419
+ </tbody>
420
+ </table>
421
+
422
+ ```python
423
+ segmentor = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_ms', source="local", weights=<SEGMENTOR/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
424
+ ```
425
+
426
+ Full example code of segmentor on an image
427
+
428
+ ```python
429
+ import sys
430
+ sys.path.append(REPO_DIR)
431
+
432
+ from PIL import Image
433
+ import torch
434
+ from torchvision import transforms
435
+ import matplotlib.pyplot as plt
436
+ from matplotlib import colormaps
437
+ from functools import partial
438
+ from dinov3.eval.segmentation.inference import make_inference
439
+
440
+
441
+ def get_img():
442
+ import requests
443
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
444
+ image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
445
+ return image
446
+
447
+ def make_transform(resize_size: int | list[int] = 768):
448
+ to_tensor = transforms.ToTensor()
449
+ resize = transforms.Resize((resize_size, resize_size), antialias=True)
450
+ normalize = transforms.Normalize(
451
+ mean=(0.485, 0.456, 0.406),
452
+ std=(0.229, 0.224, 0.225),
453
+ )
454
+ return transforms.Compose([to_tensor, resize, normalize])
455
+
456
+ segmentor = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_ms', source="local", weights=<SEGMENTOR/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
457
+
458
+ img_size = 896
459
+ img = get_img()
460
+ transform = make_transform(img_size)
461
+ with torch.inference_mode():
462
+ with torch.autocast('cuda', dtype=torch.bfloat16):
463
+ batch_img = transform(img)[None]
464
+ pred_vit7b = segmentor(batch_img) # raw predictions
465
+ # actual segmentation map
466
+ segmentation_map_vit7b = make_inference(
467
+ batch_img,
468
+ segmentor,
469
+ inference_mode="slide",
470
+ decoder_head_type="m2f",
471
+ rescale_to=(img.size[-1], img.size[-2]),
472
+ n_output_channels=150,
473
+ crop_size=(img_size, img_size),
474
+ stride=(img_size, img_size),
475
+ output_activation=partial(torch.nn.functional.softmax, dim=1),
476
+ ).argmax(dim=1, keepdim=True)
477
+ plt.figure(figsize=(12, 6))
478
+ plt.subplot(121)
479
+ plt.imshow(img)
480
+ plt.axis("off")
481
+ plt.subplot(122)
482
+ plt.imshow(segmentation_map_vit7b[0,0].cpu(), cmap=colormaps["Spectral"])
483
+ plt.axis("off")
484
+ ```
485
+
486
+
487
+
488
+
489
+ ### Pretrained heads - Zero-shot tasks with `dino.txt`
490
+
491
+ <table style="margin: auto">
492
+ <thead>
493
+ <tr>
494
+ <th rowspan="2">Backbone</th>
495
+ <th>Download</th>
496
+ </tr>
497
+ </thead>
498
+ <tbody>
499
+ <tr>
500
+ <td>ViT-L/16 distilled</td>
501
+ <td align="center">
502
+ <a href="https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/">[link]</a>,
503
+ <a href="https://dl.fbaipublicfiles.com/dinov3/thirdparty/bpe_simple_vocab_16e6.txt.gz">vocabulary</a>,
504
+ <a href="https://dl.fbaipublicfiles.com/dinov2/thirdparty/LICENSE">vocabulary license</a>
505
+ </td>
506
+ </tr>
507
+ </tbody>
508
+ </table>
509
+
510
+ The (full) dino.txt model can be loaded via PyTorch Hub:
511
+
512
+ ```python
513
+ import torch
514
+ # DINOv3
515
+ dinov3_vitl16_dinotxt_tet1280d20h24l, tokenizer = torch.hub.load(REPO_DIR, 'dinov3_vitl16_dinotxt_tet1280d20h24l', weights=<DINOTXT/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
516
+ ```
517
+
518
+
519
+ ## Installation
520
+
521
+ The training and evaluation code requires PyTorch version >= 2.7.1 as well as a few other 3rd party packages. Note that the code has only been tested with the specified versions and also expects a Linux environment. To setup all the required dependencies for training and evaluation, please follow the instructions below:
522
+
523
+ *[micromamba](https://mamba.readthedocs.io/en/latest/user_guide/micromamba.html)* **(Recommended)** - Clone the repository and then create and activate a `dinov3` conda environment using the provided environment definition:
524
+
525
+ ```shell
526
+ micromamba env create -f conda.yaml
527
+ micromamba activate dinov3
528
+ ```
529
+
530
+ ## Getting started
531
+
532
+ Several notebooks are provided to get started applying DINOv3:
533
+ - [PCA of patch features](notebooks/pca.ipynb): display the PCA of DINOv3 patch features on a foreground object (rainbow visualizations from the paper) [[Run in Google Colab]](https://colab.research.google.com/github/facebookresearch/dinov3/blob/main/notebooks/pca.ipynb)
534
+ - [Foreground segmentation](notebooks/foreground_segmentation.ipynb): train a linear foreground segmentation model based on DINOv3 features [[Run in Google Colab]](https://colab.research.google.com/github/facebookresearch/dinov3/blob/main/notebooks/foreground_segmentation.ipynb)
535
+ - [Dense and sparse matching](notebooks/dense_sparse_matching.ipynb): match patches from objects on two different images based on DINOv3 features [[Run in Google Colab]](https://colab.research.google.com/github/facebookresearch/dinov3/blob/main/notebooks/dense_sparse_matching.ipynb)
536
+ - [Segmentation tracking](notebooks/segmentation_tracking.ipynb): video segmentation tracking using a non-parametric method based on DINOv3 features [[Run in Google Colab]](https://colab.research.google.com/github/facebookresearch/dinov3/blob/main/notebooks/segmentation_tracking.ipynb)
537
+
538
+ ## Data preparation
539
+
540
+ ### ImageNet-1k
541
+
542
+ The root directory of the dataset should hold the following contents:
543
+
544
+ - `<ROOT>/test/ILSVRC2012_test_00000001.JPEG`
545
+ - `<ROOT>/test/[..]`
546
+ - `<ROOT>/test/ILSVRC2012_test_00100000.JPEG`
547
+ - `<ROOT>/train/n01440764/n01440764_10026.JPEG`
548
+ - `<ROOT>/train/[...]`
549
+ - `<ROOT>/train/n15075141/n15075141_9993.JPEG`
550
+ - `<ROOT>/val/n01440764/ILSVRC2012_val_00000293.JPEG`
551
+ - `<ROOT>/val/[...]`
552
+ - `<ROOT>/val/n15075141/ILSVRC2012_val_00049174.JPEG`
553
+ - `<ROOT>/labels.txt`
554
+
555
+ The provided dataset implementation expects a few additional metadata files to be present under the extra directory:
556
+
557
+ - `<EXTRA>/class-ids-TRAIN.npy`
558
+ - `<EXTRA>/class-ids-VAL.npy`
559
+ - `<EXTRA>/class-names-TRAIN.npy`
560
+ - `<EXTRA>/class-names-VAL.npy`
561
+ - `<EXTRA>/entries-TEST.npy`
562
+ - `<EXTRA>/entries-TRAIN.npy`
563
+ - `<EXTRA>/entries-VAL.npy`
564
+
565
+ These metadata files can be generated (once) with the following lines of Python code:
566
+
567
+ ```python
568
+ from dinov3.data.datasets import ImageNet
569
+
570
+ for split in ImageNet.Split:
571
+ dataset = ImageNet(split=split, root="<ROOT>", extra="<EXTRA>")
572
+ dataset.dump_extra()
573
+ ```
574
+
575
+ Note that the root and extra directories do not have to be distinct directories.
576
+
577
+ ### ImageNet-22k
578
+
579
+ Please adapt the [dataset class](dinov3/data/datasets/image_net_22k.py) to match your local setup.
580
+
581
+ <br />
582
+
583
+ :warning: To execute the commands provided in the next sections for training and evaluation, the `dinov3` package should be included in the Python module search path, i.e. simply prefix the command to run with `PYTHONPATH=.`.
584
+
585
+ ## Training
586
+
587
+ ### Fast setup: training DINOv3 ViT-L/16 on ImageNet-1k
588
+
589
+ Run DINOv3 pre-training on 4 H100-80GB nodes (32 GPUs) in a SLURM cluster environment with submitit:
590
+
591
+ ```shell
592
+ PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
593
+ --nodes 4 \
594
+ --config-file dinov3/configs/train/vitl_im1k_lin834.yaml \
595
+ --output-dir <PATH/TO/OUTPUT/DIR> \
596
+ train.dataset_path=ImageNet22k:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
597
+ ```
598
+ Training time is approximately 14 hours and the resulting checkpoint should reach 82.0% on k-NN eval and 83.5% on linear eval.
599
+
600
+ The training code saves the weights of the teacher in the eval folder every 12500 iterations for evaluation.
601
+
602
+ ### Exact DINOv3 setup: training DINOv3 ViT-7B/16
603
+
604
+ DINOv3 ViT-7B/16 is trained on a private dataset. The training involves 3 stages:
605
+ - Pretraining
606
+ - Gram anchoring
607
+ - High resolution adaptation
608
+
609
+ #### Pretraining
610
+
611
+ Launch DINOV3 ViT-7B/16 pretraining on 32 nodes (256 GPUs) in a SLURM cluster environment with submitit.
612
+
613
+ ```shell
614
+ PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
615
+ --nodes 32 \
616
+ --config-file dinov3/configs/train/dinov3_vit7b16_pretrain.yaml \
617
+ --output-dir <PATH/TO/OUTPUT/DIR> \
618
+ train.dataset_path=<DATASET>:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
619
+ ```
620
+
621
+ #### Gram anchoring
622
+
623
+ ```shell
624
+ PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
625
+ --nodes 32 \
626
+ --config-file dinov3/configs/train/dinov3_vit7b16_gram_anchor.yaml \
627
+ --output-dir <PATH/TO/OUTPUT/DIR> \
628
+ train.dataset_path=<DATASET>:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
629
+ gram.ckpt=<PATH/TO/GRAM_TEACHER_FROM_PREVIOUS_STEP>
630
+ ```
631
+
632
+ #### High-resolution adaptation
633
+
634
+
635
+ ```shell
636
+ PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
637
+ --nodes 32 \
638
+ --config-file dinov3/configs/train/dinov3_vit7b16_high_res_adapt.yaml \
639
+ --output-dir <PATH/TO/OUTPUT/DIR> \
640
+ train.dataset_path=<DATASET>:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
641
+ gram.ckpt=<PATH/TO/TEACHER_FROM_GRAM> \
642
+ student.resume_from_teacher_chkpt=<PATH/TO/TEACHER_FROM_GRAM>
643
+ ```
644
+
645
+ ## Multi-distillation
646
+
647
+ ### Test setup:
648
+
649
+ ```shell
650
+ PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
651
+ --nodes 1 \
652
+ --config-file dinov3/configs/train/multi_distillation_test.yaml \
653
+ --output-dir <PATH/TO/OUTPUT/DIR> \
654
+ --multi-distillation \
655
+ train.dataset_path=<DATASET>:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
656
+ ```
657
+
658
+ ## Evaluation
659
+
660
+ The training code regularly saves the teacher weights. In order to evaluate the model, run the following evaluation on a single node:
661
+
662
+
663
+ ### Logistic regression classification on ImageNet-1k
664
+
665
+ ```shell
666
+ PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/eval/log_regression.py \
667
+ model.config_file=<PATH/TO/OUTPUT/DIR>/config.yaml \
668
+ model.pretrained_weights=<PATH/TO/OUTPUT/DIR>/teacher_checkpoint.pth \
669
+ output_dir=<PATH/TO/OUTPUT/DIR> \
670
+ train.dataset=ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
671
+ eval.test_dataset=ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
672
+ ```
673
+
674
+ ### k-NN classification on ImageNet-1k
675
+
676
+ ```shell
677
+ PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/eval/knn.py \
678
+ model.config_file=<PATH/TO/OUTPUT/DIR>/config.yaml \
679
+ model.pretrained_weights=<PATH/TO/OUTPUT/DIR>/teacher_checkpoint.pth \
680
+ output_dir=<PATH/TO/OUTPUT/DIR> \
681
+ train.dataset=ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
682
+ eval.test_dataset=ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
683
+ ```
684
+
685
+ ### Linear classification with data augmentation on ImageNet-1k
686
+
687
+ ```shell
688
+ PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/eval/linear.py \
689
+ model.config_file=<PATH/TO/OUTPUT/DIR>/config.yaml \
690
+ model.pretrained_weights=<PATH/TO/OUTPUT/DIR>/teacher_checkpoint.pth \
691
+ output_dir=<PATH/TO/OUTPUT/DIR> \
692
+ train.dataset=ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
693
+ train.val_dataset=ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
694
+ ```
695
+
696
+
697
+ ### Text alignment on DINOv3 using dino.txt
698
+
699
+ Text alignment can be done following the method from `dino.txt` aka [DINOv2 Meets Text](https://arxiv.org/abs/2412.16334).
700
+
701
+ ```shell
702
+ PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/eval/text/train_dinotxt.py \
703
+ --nodes 4 \
704
+ # An example config for text alignment is here: dinov3/eval/text/configs/dinov3_vitl_text.yaml \
705
+ trainer_config_file="<PATH/TO/DINOv3/TEXT/CONFIG>" \
706
+ output-dir=<PATH/TO/OUTPUT/DIR>
707
+ ```
708
+ Launching the above trains text alignment on 4 nodes with 8 gpus each (32 gpus in total).
709
+ Please note that the text alignment model in the DINOv3 paper was trained on a private dataset and here we have given an example config in ```dinov3/eval/text/configs/dinov3_vitl_text.yaml``` using ```CocoCaptions``` dataset for illustration purposes.
710
+ Please adapt the provided ```CocoCaptions``` dataset class, the dataset can be found [here](https://www.kaggle.com/datasets/nikhil7280/coco-image-caption)
711
+
712
+ ## License
713
+
714
+ DINOv3 code and model weights are released under the DINOv3 License. See [LICENSE.md](LICENSE.md) for additional details.
715
+
716
+ ## Contributing
717
+
718
+ See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
719
+
720
+ ## Citing DINOv3
721
+
722
+ If you find this repository useful, please consider giving a star :star: and citation :t-rex::
723
+
724
+ ```
725
+ @misc{simeoni2025dinov3,
726
+ title={{DINOv3}},
727
+ author={Sim{\'e}oni, Oriane and Vo, Huy V. and Seitzer, Maximilian and Baldassarre, Federico and Oquab, Maxime and Jose, Cijo and Khalidov, Vasil and Szafraniec, Marc and Yi, Seungeun and Ramamonjisoa, Micha{\"e}l and Massa, Francisco and Haziza, Daniel and Wehrstedt, Luca and Wang, Jianyuan and Darcet, Timoth{\'e}e and Moutakanni, Th{\'e}o and Sentana, Leonel and Roberts, Claire and Vedaldi, Andrea and Tolan, Jamie and Brandt, John and Couprie, Camille and Mairal, Julien and J{\'e}gou, Herv{\'e} and Labatut, Patrick and Bojanowski, Piotr},
728
+ year={2025},
729
+ eprint={2508.10104},
730
+ archivePrefix={arXiv},
731
+ primaryClass={cs.CV},
732
+ url={https://arxiv.org/abs/2508.10104},
733
+ }
734
+ ```
InfiniDepth/model/block/torchhub/dinov3/conda.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: dinov3
2
+ channels:
3
+ - defaults
4
+ - conda-forge
5
+ dependencies:
6
+ - python=3.11
7
+ - omegaconf
8
+ - pip
9
+ - pip:
10
+ - ftfy # needed for dino.txt
11
+ - iopath
12
+ - omegaconf
13
+ - pandas
14
+ - regex # needed for dino.txt
15
+ - pandas
16
+ - scikit-learn
17
+ - scikit-learn-intelex
18
+ - submitit
19
+ - termcolor
20
+ - torch
21
+ - torchvision
22
+ - torchmetrics
23
+
InfiniDepth/model/block/torchhub/dinov3/dinov3/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ __version__ = "0.0.1"
InfiniDepth/model/block/torchhub/dinov3/dinov3/checkpointer/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ from .checkpointer import (
7
+ CheckpointRetentionPolicy,
8
+ cleanup_checkpoint,
9
+ find_all_checkpoints,
10
+ find_latest_checkpoint,
11
+ init_fsdp_model_from_checkpoint,
12
+ init_model_from_checkpoint_for_evals,
13
+ keep_checkpoint_copy,
14
+ keep_last_n_checkpoints,
15
+ load_checkpoint,
16
+ register_dont_save_hooks,
17
+ save_checkpoint,
18
+ )
InfiniDepth/model/block/torchhub/dinov3/dinov3/checkpointer/checkpointer.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ """
7
+ Suggested file structure:
8
+
9
+ output_dir/
10
+ |-- ckpt/
11
+ | |-- 0/
12
+ | |-- 99/
13
+ | |-- 199/
14
+ | |-- 199_keep/
15
+ | |-- 299/
16
+ | `-- ...
17
+ `-- eval/
18
+ `-- 0/
19
+ `-- 99/
20
+ `-- ckpt/
21
+
22
+ Distributed checkpointer docs:
23
+ - https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
24
+ - https://pytorch.org/docs/stable/distributed.checkpoint.html
25
+ """
26
+
27
+ import logging
28
+ import shutil
29
+ import subprocess
30
+ import tempfile
31
+ from enum import Enum
32
+ from pathlib import Path
33
+ from typing import List, Sequence, Set
34
+
35
+ import torch
36
+ import torch.distributed as dist
37
+ import torch.distributed.checkpoint as dcp
38
+ import torch.distributed.checkpoint.filesystem as dcpfs
39
+ import torch.distributed.checkpoint.state_dict as dcpsd
40
+ from torch.distributed.checkpoint.stateful import Stateful
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ class CheckpointRetentionPolicy(Enum):
46
+ ALL = "all" # keep all checkpoints
47
+ BEST = "best"
48
+ LAST = "last"
49
+ LAST_AND_BEST = "last_and_best"
50
+ NONE = "none" # do not keep any checkpoints
51
+
52
+ @property
53
+ def keep_filters(self) -> Set[str]:
54
+ """Files that match these patterns are not deleted by cleanup"""
55
+ if self == CheckpointRetentionPolicy.LAST:
56
+ return set(["final"])
57
+ if self == CheckpointRetentionPolicy.BEST:
58
+ return set(["best"])
59
+ if self == CheckpointRetentionPolicy.LAST_AND_BEST:
60
+ return set(["final", "best"])
61
+ if self == CheckpointRetentionPolicy.ALL:
62
+ return set()
63
+ return set()
64
+
65
+ @property
66
+ def max_to_keep(self) -> int | None:
67
+ """
68
+ maximum "periodic" checkpoints to keep concurrently, ie. saved with `step` and not `save`. `None` for keep all
69
+ """
70
+ if self == CheckpointRetentionPolicy.ALL:
71
+ return None
72
+ return 1
73
+
74
+
75
def save_checkpoint(
    ckpt_dir: str | Path,  # output_dir/ckpt/199
    *,
    iteration: int | str,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer | None = None,
    overwrite: bool = True,
    process_group: dist.ProcessGroup = None,
    **others: Stateful,
):
    """Save a plain/DDP/FSDP/FSDP2 model, its optimizer, an integer iteration and other stateful objects.

    The checkpoint is first written to a temporary directory next to ``ckpt_dir``
    and atomically renamed at the end, so a crash mid-save never leaves a
    half-written directory at the final path.

    Args:
        ckpt_dir: Destination directory for the distributed checkpoint.
        iteration: Training iteration persisted alongside the weights.
        model: Model to save (plain, DDP, FSDP or FSDP2 wrapped).
        optimizer: Optional optimizer whose state is saved with the model.
        overwrite: If True, delete any existing checkpoint at ``ckpt_dir`` first.
        process_group: Process group participating in the collective save
            (``None`` means the default/world group).
        **others: Extra ``Stateful`` objects to persist under their keyword names.

    Raises:
        RuntimeError: If ``ckpt_dir`` already exists and ``overwrite`` is False.
    """
    rank = torch.distributed.get_rank(group=process_group)

    # Rank 0 checks if the checkpoint directory exists, but all ranks need to know if it exists,
    # so they can raise an error when overwrite is False. If overwrite is True, rank 0 will delete it
    # and other ranks wait for the deletion to finish.
    ckpt_dir = Path(ckpt_dir)
    ckpt_dir_exists = [ckpt_dir.exists() if rank == 0 else None]
    src_rank = 0
    if process_group is not None:
        src_rank = torch.distributed.get_global_rank(group=process_group, group_rank=0)
    torch.distributed.broadcast_object_list(ckpt_dir_exists, src=src_rank, group=process_group)
    ckpt_dir_exists = ckpt_dir_exists[0]
    if ckpt_dir_exists:
        if overwrite:
            if rank == 0:
                if ckpt_dir.is_dir():
                    shutil.rmtree(ckpt_dir)
                else:
                    ckpt_dir.unlink()
                logger.info(f"Deleted: {ckpt_dir}")
            torch.distributed.barrier(group=process_group)
        else:
            raise RuntimeError(f"Checkpoint already exists: {ckpt_dir}")

    # Rank 0 creates a temporary directory for the checkpoint and broadcasts the name to all ranks.
    ckpt_dir.parent.mkdir(parents=True, exist_ok=True)
    ckpt_dir_tmp = [tempfile.mkdtemp(dir=ckpt_dir.parent, prefix=ckpt_dir.name) if rank == 0 else None]
    torch.distributed.broadcast_object_list(ckpt_dir_tmp, src=src_rank, group=process_group)
    ckpt_dir_tmp = Path(ckpt_dir_tmp[0])

    to_save = {"iteration": iteration}
    to_save["model"] = dcpsd.get_model_state_dict(model)
    if optimizer is not None:
        to_save["optimizer"] = dcpsd.get_optimizer_state_dict(model, optimizer)
    to_save.update(others)
    dcp.save(
        to_save,
        storage_writer=dcpfs.FileSystemWriter(ckpt_dir_tmp),
        process_group=process_group,
    )

    # Rank 0 renames the temporary directory to the final checkpoint directory. All ranks wait for the rename.
    if rank == 0:
        ckpt_dir_tmp.rename(ckpt_dir)
    # Fix: this barrier previously ran on the default (world) group; when `process_group`
    # is a sub-group that mismatch can synchronize the wrong ranks or deadlock.
    # Use the same group as every other collective in this function.
    torch.distributed.barrier(group=process_group)

    logger.info(f"Saved: {ckpt_dir}")
133
+
134
+
135
def load_checkpoint(
    ckpt_dir: str | Path,  # output_dir/ckpt/199
    *,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer | None = None,
    strict_loading: bool = True,
    process_group: dist.ProcessGroup | None = None,  # was `dist.ProcessGroup = None` (implicit Optional)
    **others: Stateful,
) -> int | None:
    """
    Load a plain/DDP/FSDP/FSDP2 model, its optimizer, an integer iteration and other stateful objects.
    Can you take a checkpoint saved on N ranks and load it on M ranks? Sure you can!
    Activation checkpointing and torch-compile can also be different between save and load, no problem.

    Args:
        ckpt_dir: Directory containing the DCP checkpoint to read.
        model: Model whose weights are populated in place.
        optimizer: Optional optimizer whose state is populated in place.
        strict_loading: When False, keys missing from the checkpoint are tolerated (partial load).
        process_group: Process group used for the load; None means the default group.
        **others: Additional Stateful objects restored in place by dcp.load.

    Returns:
        The iteration number stored in the checkpoint, or None if it was absent.
    """
    ckpt_dir = Path(ckpt_dir)
    # Build the skeleton state dict; dcp.load fills these entries in place.
    to_load = {"iteration": None}
    to_load["model"] = dcpsd.get_model_state_dict(model)
    if optimizer is not None:
        to_load["optimizer"] = dcpsd.get_optimizer_state_dict(model, optimizer)
    to_load.update(others)
    dcp.load(
        to_load,
        storage_reader=dcpfs.FileSystemReader(ckpt_dir),
        planner=dcp.default_planner.DefaultLoadPlanner(allow_partial_load=not strict_loading),
        process_group=process_group,
    )
    iteration = to_load["iteration"]
    # Push the loaded (possibly resharded) tensors back into the live model/optimizer.
    dcpsd.set_model_state_dict(model, to_load["model"])
    if optimizer is not None:
        dcpsd.set_optimizer_state_dict(model, optimizer, to_load["optimizer"])
    logger.info(f"Loaded: {ckpt_dir}")
    return iteration
167
+
168
+
169
+ def register_dont_save_hooks(module: torch.nn.Module, dont_save: Sequence[str]):
170
+ """
171
+ Registers save/load state dict hooks such that the weights in `dont_save` are not persisted in the checkpoint.
172
+
173
+ Typical use case: a classification model composed of a frozen backbone and a trainable head.
174
+ If the frozen backbone is loaded from torch hub, it does't make sense to save a copy of it in each checkpoint.
175
+ """
176
+
177
+ def state_dict_post_hook(module, state_dict, prefix, local_metadata):
178
+ # Remove frozen weights so they won't get saved.
179
+ # If this module is not the top-level module, its weights will have a prefix in the state dict.
180
+ nonlocal _dont_save
181
+ for k in _dont_save:
182
+ del state_dict[prefix + k]
183
+
184
+ def load_state_dict_pre_hook(
185
+ module,
186
+ state_dict,
187
+ prefix,
188
+ local_metadata,
189
+ strict,
190
+ missing_keys,
191
+ unexpected_keys,
192
+ error_msgs,
193
+ ):
194
+ # This pre hook exists only to pass the prefix to the post hook when loading the state dict.
195
+ nonlocal _prefix
196
+ assert _prefix is None
197
+ _prefix = prefix
198
+
199
+ def load_state_dict_post_hook(module, incompatible_keys):
200
+ # Remove the frozen weights from the missing keys so they don't raise an error.
201
+ nonlocal _prefix
202
+ assert _prefix is not None
203
+ to_remove = []
204
+ for missing_key in incompatible_keys.missing_keys:
205
+ k = missing_key.removeprefix(_prefix)
206
+ k = k.replace("_checkpoint_wrapped_module.", "") # Added by activation checkpointing
207
+ if k in _dont_save:
208
+ to_remove.append(missing_key)
209
+ for r in to_remove:
210
+ incompatible_keys.missing_keys.remove(r)
211
+ _prefix = None
212
+
213
+ _dont_save = set(name.replace("_checkpoint_wrapped_module.", "") for name in dont_save)
214
+ _prefix = None
215
+ module.register_state_dict_post_hook(state_dict_post_hook)
216
+ module.register_load_state_dict_pre_hook(load_state_dict_pre_hook)
217
+ module.register_load_state_dict_post_hook(load_state_dict_post_hook)
218
+
219
+
220
+ def find_all_checkpoints(ckpt_dir: Path | str) -> list[Path]:
221
+ """Find all checkpoints in a directory, i.e. subdirs with integer name. Sorted from first to last."""
222
+ ckpt_dir = Path(ckpt_dir)
223
+ if not ckpt_dir.is_dir():
224
+ return []
225
+ checkpoints = [p for p in ckpt_dir.iterdir() if p.is_dir() and _is_int(p.name)]
226
+ checkpoints.sort(key=lambda p: int(p.name))
227
+ return checkpoints
228
+
229
+
230
def find_latest_checkpoint(ckpt_dir: Path | str) -> Path | None:
    """Find the latest checkpoint in a directory, i.e. the subdir with the highest integer name."""
    found = find_all_checkpoints(ckpt_dir)
    # find_all_checkpoints sorts ascending, so the last entry is the newest.
    return found[-1] if found else None
236
+
237
+
238
def keep_last_n_checkpoints(ckpt_dir: Path | str, n: int | None):
    """In a directory with integer-named subdirs, keep only the n subdirs with the highest number.

    Args:
        ckpt_dir: Directory containing integer-named checkpoint subdirs.
        n: Number of most recent checkpoints to keep. None disables cleanup entirely;
           0 deletes every checkpoint.
    """
    if n is None:
        return
    checkpoints = find_all_checkpoints(ckpt_dir)
    # `checkpoints[:-0]` is the empty slice, which would silently keep everything,
    # so "keep zero" must be handled explicitly.
    to_delete = checkpoints if n == 0 else checkpoints[:-n]
    for old_ckpt in to_delete:  # renamed from `ckpt_dir`, which shadowed the parameter
        try:
            shutil.rmtree(old_ckpt)
            logger.info(f"Deleted: {old_ckpt}")
        except Exception:
            # Best effort: a failed delete should not abort training.
            logger.exception(f"Failed to delete: {old_ckpt}")
249
+
250
+
251
def keep_checkpoint_copy(src: Path | str):
    """Copy a file/directory next to itself with a _keep suffix. Files are hardlinked."""
    source = Path(src)
    destination = source.with_name(source.name + "_keep")
    # `cp --link` hardlinks regular files, so the copy is cheap and shares storage.
    subprocess.check_output(["cp", "--recursive", "--link", source, destination])
    logger.info(f"Copied: {source} -> {destination}")
257
+
258
+
259
+ def _is_int(s: str) -> bool:
260
+ try:
261
+ int(s)
262
+ return True
263
+ except ValueError:
264
+ return False
265
+
266
+
267
# Initialize a FSDP2 model from DCP or PyTorch standard checkpoint
def init_fsdp_model_from_checkpoint(
    model: torch.nn.Module,
    checkpoint_path: str,
    skip_load_prefixes: List[str] | None = None,
    prefixes_not_sharded: List[str] | None = None,
    process_group: dist.ProcessGroup | None = None,
):
    """Initialize an FSDP2-wrapped model from a DCP directory or a plain torch.save checkpoint.

    Args:
        model: FSDP2 model to load weights into.
        checkpoint_path: DCP checkpoint directory, or a torch.save file containing a "teacher" entry.
        skip_load_prefixes: Keys starting with any of these prefixes are not loaded (plain checkpoints only).
        prefixes_not_sharded: Keys starting with any of these prefixes stay replicated instead of sharded.
        process_group: Process group describing the data-parallel mesh; None uses the whole world.
    """
    # The original code iterated these without a None check, crashing on the defaults.
    skip_load_prefixes = skip_load_prefixes or []
    prefixes_not_sharded = prefixes_not_sharded or []
    if not Path(checkpoint_path).is_dir():  # PyTorch standard checkpoint
        logger.info(f"Loading pretrained weights from {checkpoint_path}")
        chkpt = torch.load(checkpoint_path, map_location="cpu")["teacher"]
        from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

        if process_group is None:
            world_mesh = init_device_mesh(
                "cuda",
                mesh_shape=(dist.get_world_size(),),
                mesh_dim_names=("dp",),
            )
        else:
            world_mesh = DeviceMesh.from_group(process_group, "cuda")
        # Shard every tensor across the mesh except those explicitly kept replicated.
        # Each key must be tested against ALL prefixes at once: the original comprehension
        # looped over prefixes in an outer `for` clause, so with several prefixes the last
        # one silently decided the sharding of every key (and an empty prefix list
        # produced an empty dict).
        chkpt = {
            k: (
                v
                if any(k.startswith(pns) for pns in prefixes_not_sharded)
                else torch.distributed.tensor.distribute_tensor(v, world_mesh, src_data_rank=None)
            )
            for k, v in chkpt.items()
        }
        model.load_state_dict(
            {k: v for k, v in chkpt.items() if not any(k.startswith(prefix) for prefix in skip_load_prefixes)}
        )
    else:  # DCP checkpoint
        load_checkpoint(ckpt_dir=checkpoint_path, model=model, process_group=process_group)
302
+
303
+
304
# Initialize a standard non distributed PyTorch model from PyTorch standard checkpoint for evals
def init_model_from_checkpoint_for_evals(
    model: torch.nn.Module, pretrained_weights: str | Path, checkpoint_key: str | None = None
):
    """Load pretrained weights into a plain (non-distributed) model for evaluation.

    Args:
        model: Model to load the weights into; loading is non-strict, mismatches are logged.
        pretrained_weights: Path to a torch.save checkpoint file.
        checkpoint_key: If given and present in the checkpoint dict, load from that sub-dict.
          (Annotation fixed from implicit-Optional `str = None`.)
    """
    state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]
    # remove `module.` prefix
    # NOTE(review): .replace() strips the substring anywhere in the key, not only as a
    # prefix — assumed intentional for these checkpoints; confirm before tightening.
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    # remove `backbone.` prefix induced by multicrop wrapper
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    msg = model.load_state_dict(state_dict, strict=False)
    logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
318
+
319
+
320
def cleanup_checkpoint(ckpt_dir: str, checkpoint_retention_policy: CheckpointRetentionPolicy):
    """
    Delete every checkpoint subdirectory that the retention policy does not keep.

    ckpt_dir is the directory containing each individual checkpoint directories (either at iteration, best (validation performance) or final)
    |-- ckpt_dir/
    |   |-- 0/
    |       |--checkpoint.pth or dcp_sharded_checkpoint_dir
    |   |-- 99/
    |       |--checkpoint.pth or dcp_sharded_checkpoint_dir
    |   |-- 199/
    |       |--checkpoint.pth or dcp_sharded_checkpoint_dir
    |   |-- best/
    |       |--checkpoint.pth or dcp_sharded_checkpoint_dir
    |   |-- 299/
    |       |--checkpoint.pth or dcp_sharded_checkpoint_dir
    |   |-- final/
    |       |--checkpoint.pth or dcp_sharded_checkpoint_dir
    """
    ckpt_dir = Path(ckpt_dir)
    if not ckpt_dir.is_dir():
        return []
    checkpoint_filters = checkpoint_retention_policy.keep_filters
    checkpoints = [p for p in ckpt_dir.iterdir() if p.is_dir()]
    for checkpoint in checkpoints:
        # Also match on the directory *name* ("best", "final", "199", ...): `checkpoint`
        # is a Path and a Path never compares equal to a str, so the original
        # Path-only membership test could never retain string-named filters.
        if checkpoint.name in checkpoint_filters or checkpoint in checkpoint_filters:
            continue
        try:
            shutil.rmtree(checkpoint)
            logger.info(f"Deleted: {checkpoint}")
        except Exception:
            # Best effort: a failed delete should not abort the run.
            logger.exception(f"Failed to delete: {checkpoint}")
InfiniDepth/model/block/torchhub/dinov3/dinov3/configs/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ from .config import (
7
+ DinoV3SetupArgs,
8
+ apply_scaling_rules_to_cfg,
9
+ exit_job,
10
+ get_cfg_from_args,
11
+ get_default_config,
12
+ setup_config,
13
+ setup_job,
14
+ setup_multidistillation,
15
+ write_config,
16
+ )
InfiniDepth/model/block/torchhub/dinov3/dinov3/configs/config.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ import logging
7
+ import math
8
+ import os
9
+ import pathlib
10
+ import sys
11
+ from dataclasses import dataclass, field
12
+ from datetime import timedelta
13
+ from typing import Any, List, Optional, Sequence, Tuple
14
+
15
+ from omegaconf import DictConfig, OmegaConf
16
+
17
+ import dinov3.distributed as distributed
18
+ from dinov3.logging import cleanup_logging, setup_logging
19
+ from dinov3.utils import fix_random_seeds, get_conda_env, get_sha
20
+
21
+ logger = logging.getLogger("dinov3")
22
+
23
+
24
@dataclass
class DinoV3SetupArgs:
    """Arguments for setting up a DINOv3 job: config file, optional weights, output dir, CLI overrides."""

    config_file: str
    pretrained_weights: str | None = None
    shard_unsharded_model: bool = False
    output_dir: str = ""
    opts: List[Any] = field(default_factory=list)

    def __post_init__(self):
        # When loaded from benchmark.yaml, self.opts is a frozen omegaconf.ListConfig,
        # which works everywhere except when we want to modify it or when
        # we try to json-serialize it. So we convert it to a regular list here.
        if OmegaConf.is_config(self.opts):
            self.opts = OmegaConf.to_object(self.opts)
38
+
39
+
40
def apply_scaling_rules_to_cfg(cfg):  # to fix
    """Scale cfg.optim.lr in place according to cfg.optim.scaling_rule and the global batch size."""
    assert distributed.is_enabled(), "Setup distributed to get global size !"
    if "schedules" in cfg:
        # For schedules v2, the scaling rules are applied when building the schedules, the config is not modified
        return cfg

    global_batch_size = cfg.train.batch_size_per_gpu * distributed.get_world_size()
    rule = cfg.optim.scaling_rule
    if rule == "linear_wrt_256":
        previous_lr = cfg.optim.lr
        cfg.optim.lr = previous_lr * (global_batch_size / 256.0)
        logger.info(f"linear scaling learning rate; old: {previous_lr}, new: {cfg.optim.lr}")
    elif rule == "sqrt_wrt_1024":
        previous_lr = cfg.optim.lr
        cfg.optim.lr = previous_lr * (4 * math.sqrt(global_batch_size / 1024.0))
        logger.info(f"sqrt scaling learning rate; old: {previous_lr}, new: {cfg.optim.lr}")
    return cfg
55
+
56
+
57
def write_config(cfg, output_dir, name="config.yaml"):
    """Log cfg as YAML, save it to output_dir/name, and return the saved path (a str)."""
    logger.info(OmegaConf.to_yaml(cfg))
    target = os.path.join(os.path.abspath(output_dir), name)
    with open(target, "w") as f:
        OmegaConf.save(config=cfg, f=f)
    return target
64
+
65
+
66
def get_default_config() -> DictConfig:
    """Load the bundled ssl_default_config.yaml living next to this module."""
    return OmegaConf.load(pathlib.Path(__file__).parent / "ssl_default_config.yaml")
69
+
70
+
71
def get_cfg_from_args(args: DinoV3SetupArgs, multidistillation=False, strict=True):
    """Build the job config: defaults <- config file <- CLI overrides (last wins in merges).

    NOTE(review): the check below is `is not None`, but DinoV3SetupArgs.output_dir
    defaults to "" — an empty string also passes and realpath("") resolves to the
    current directory; confirm this is intended.
    """
    overrides = [*args.opts]
    if args.output_dir is not None:
        overrides.append(f"train.output_dir={os.path.realpath(args.output_dir)}")

    # Config file
    cfg = OmegaConf.load(args.config_file)

    # Command line overrides
    opts_cfg = OmegaConf.from_cli(overrides)

    if multidistillation:
        # Multidistillation: merge without the defaults (the base cfg carries its own).
        cfg = OmegaConf.merge(cfg, opts_cfg)
    else:
        # Default config
        default_cfg = get_default_config()
        if strict:
            # Struct mode: keys absent from the defaults raise on merge instead of
            # being silently added.
            OmegaConf.set_struct(default_cfg, True)
        cfg = OmegaConf.merge(default_cfg, cfg, opts_cfg)
    return cfg
91
+
92
+
93
def setup_config(args: DinoV3SetupArgs, strict_cfg=True):
    """
    Create configs and perform basic setups.
    """
    # Build the cfg with OmegaConf (defaults + config file + CLI overrides).
    cfg = get_cfg_from_args(args, strict=strict_cfg)
    # Log the setup arguments, sorted for stable output.
    logger.info("\n".join(f"{key}: {value}" for key, value in sorted(vars(args).items())))
    # Dump the config before applying scaling rules so it can be reloaded verbatim.
    if args.output_dir is not None:
        write_config(cfg, args.output_dir)
    # Modify the config in place by applying scaling rules.
    apply_scaling_rules_to_cfg(cfg)
    return cfg
107
+
108
+
109
+ def _enumerate_all_subgroup_ranks(all_subgroup_rank_spans: Sequence[Tuple[int, int]]):
110
+ """Expands a specification of process subgroups from spans to enumerated ranks.
111
+
112
+ Args:
113
+ all_group_rank_spans: a sequence of rank spans (first rank, last rank),
114
+ one for each process group. Example: ((0, 1), (2, 3), (4, 7)).
115
+ """
116
+ for first, last in all_subgroup_rank_spans:
117
+ assert first <= last
118
+ return tuple(tuple(range(first, last + 1)) for first, last in all_subgroup_rank_spans)
119
+
120
+
121
def setup_multidistillation(args: DinoV3SetupArgs):
    """Set up a multi-distillation run: init distributed, build one process subgroup per
    student, then select and return the merged config for the student owning this rank.

    Mutates `args` in place: output_dir, opts and config_file are redirected to this
    rank's student.
    """
    base_output_dir = args.output_dir
    os.makedirs(args.output_dir, exist_ok=True)
    # get config file for this rank
    base_cfg = OmegaConf.load(args.config_file)
    assert base_cfg.multidistillation.enabled

    global_batch_size = base_cfg.multidistillation.global_batch_size

    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_rank()

    # build process subgroups
    # ranks_range is half-open [first, last), hence the -1 to get an inclusive span.
    all_subgroup_rank_spans = tuple(
        (student.ranks_range[0], student.ranks_range[1] - 1) for student in base_cfg.multidistillation.students
    )
    all_subgroup_ranks = _enumerate_all_subgroup_ranks(all_subgroup_rank_spans)
    distributed.new_subgroups(all_subgroup_ranks)

    found = False
    for student in base_cfg.multidistillation.students:
        if rank in range(*student.ranks_range):
            found = True
            break
    assert found, "rank of worker not in defined range"

    # NOTE: `student` deliberately leaks from the loop above — it is this rank's student.
    name = student.name
    config_path = student.config_path
    n_gpus = student.ranks_range[1] - student.ranks_range[0]
    assert global_batch_size % n_gpus == 0
    total_n_gpus = distributed.get_world_size()

    args.output_dir = os.path.join(base_output_dir, name)
    args.opts += [f"train.output_dir={args.output_dir}"]
    args.opts += [f"train.batch_size_per_gpu={global_batch_size // total_n_gpus}"]
    args.config_file = os.path.abspath(config_path)
    default_cfg = get_default_config()
    cfg = OmegaConf.load(args.config_file)
    # Merge precedence (last wins): defaults < student config < base cfg < CLI opts.
    cfg = OmegaConf.merge(default_cfg, cfg, base_cfg, OmegaConf.from_cli(args.opts))

    # NOTE(review): `global logger` is declared but never assigned in this function,
    # so this statement is a no-op here.
    global logger
    setup_logging(output=args.output_dir, level=logging.INFO)

    # Offset the seed by rank so each process seeds differently.
    fix_random_seeds(seed + rank)

    write_config(cfg, args.output_dir)
    apply_scaling_rules_to_cfg(cfg)

    return cfg
171
+
172
+
173
def setup_job(
    output_dir: Optional[str] = None,
    distributed_enabled: bool = True,
    logging_enabled: bool = True,
    seed: Optional[int] = 0,
    restrict_print_to_main_process: bool = True,
    distributed_timeout: timedelta | None = None,
):
    """
    Setup methods that should be done in every fairvit job
    Initializes logging, distributed, random seeds and other utilities.

    Args:
        output_dir: Directory for logs; created if missing (resolved to a real path).
        distributed_enabled: Whether to initialize the distributed backend.
        logging_enabled: Whether to configure logging handlers.
        seed: Base random seed; each rank seeds with seed + rank. None skips seeding.
        restrict_print_to_main_process: Restrict stdout logging to the main process.
        distributed_timeout: Optional timeout forwarded to distributed.enable().
    """
    if output_dir is not None:
        output_dir = os.path.realpath(output_dir)
        os.makedirs(output_dir, exist_ok=True)

    if logging_enabled:
        setup_logging(
            output=output_dir,
            level=logging.INFO,
            log_to_stdout_only_in_main_process=restrict_print_to_main_process,
        )

    if distributed_enabled:
        distributed.enable(
            overwrite=True,
            nccl_async_error_handling=True,
            restrict_print_to_main_process=restrict_print_to_main_process,
            timeout=distributed_timeout,
        )

    if seed is not None:
        # Offset the seed by rank so each process seeds differently.
        rank = distributed.get_rank()
        fix_random_seeds(seed + rank)

    # NOTE(review): this local `logger` shadows the module-level "dinov3" logger within
    # this function; the messages below go to the "fairvit" logger instead.
    logger = logging.getLogger("fairvit")
    logger.info("git:\n {}\n".format(get_sha()))

    # Log some python info
    conda_env_name, conda_env_path = get_conda_env()
    logger.info(f"conda env name: {conda_env_name}")
    logger.info(f"conda env path: {conda_env_path}")
    logger.info(f"python path: {sys.path}")
216
+
217
+
218
def exit_job(distributed_enabled: bool = True, logging_enabled: bool = True):
    """Tear down job-level resources: distributed process groups first, then logging handlers.

    Args:
        distributed_enabled: Whether distributed was enabled and should be disabled.
        logging_enabled: Whether logging was set up and should be cleaned up.
    """
    if distributed_enabled:
        distributed.disable()
    if logging_enabled:
        cleanup_logging()
InfiniDepth/model/block/torchhub/dinov3/dinov3/configs/ssl_default_config.yaml ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL:
2
+ META_ARCHITECTURE: SSLMetaArch
3
+ DEVICE: cuda
4
+ WEIGHTS: ''
5
+ DTYPE: float32
6
+ compute_precision:
7
+ param_dtype: bf16
8
+ reduce_dtype: fp32
9
+ sharding_strategy: SHARD_GRAD_OP
10
+ dino:
11
+ loss_weight: 1.0
12
+ global_ignore_diagonal: true # Whether to ignore A-A and B-B global pairs, default as in DINOv2, ignored by SSLMetaArch
13
+ head_n_prototypes: 65536
14
+ head_bottleneck_dim: 256
15
+ head_norm_last_layer: false
16
+ head_nlayers: 3
17
+ head_hidden_dim: 2048
18
+ koleo_loss_weight: 0.1
19
+ koleo_loss_distributed: false
20
+ koleo_topk: 1
21
+ koleo_distributed_replicas: 0
22
+ koleo_distributed_loss_group_size: null # Size of the nearest neighbor set for distributed Koleo. If None, uses global batch size.
23
+ koleo_distributed_loss_group_data: true # group data from adjacent ranks to make sure koleo is applied on the same data distribution
24
+ force_weight_norm: false
25
+ reweight_dino_local_loss: false # If true, reweighting of DINO loss
26
+ local_loss_weight_schedule: # Schedule for local loss weight, enabled if reweight_dino_local_loss is true
27
+ start: 0.5
28
+ peak: 0.5
29
+ end: 0.5
30
+ warmup_epochs: 0
31
+ ibot:
32
+ loss_weight: 1.0
33
+ mask_sample_probability: 0.5
34
+ mask_ratio_min_max:
35
+ - 0.1
36
+ - 0.5
37
+ mask_random_circular_shift: false
38
+ force_masking_even_with_zero_weight: False
39
+ separate_head: true
40
+ head_n_prototypes: 65536
41
+ head_bottleneck_dim: 256
42
+ head_norm_last_layer: false
43
+ head_nlayers: 3
44
+ head_hidden_dim: 2048
45
+ gram:
46
+ use_loss: false # (bool) if true gram is used, else not
47
+ compute_stats: false # (bool): if true compute auxiliary stats
48
+ loss_weight: 1.0 # (float): weight of the loss
49
+ ema_teacher: false # (bool): using the EMA teacher as GRAM teacher
50
+ ckpt: null #(str): Checkpoint to the teacher
51
+ it_load_ema_teacher: -1 # (int): iteration at which the ema teacher is loaded into the gram teacher
52
+ rep_update: true # (bool): if true GRAM teacher updated every gram.update_frequency after iter gram.it_first_update steps
53
+ update_frequency: 50000 # (int): update frequency
54
+ it_first_update: 0 # (int): iteration of the first update
55
+ max_updates: null # (int): maximum number of updates to gram teacher. If None, it is unlimited
56
+ normalized: true # (bool): normalization of the features
57
+ img_level: false # (bool): if true GRAM computation at the image else, otherwise at the local batch level
58
+ remove_neg: false # (bool): if true remove the negative similarities before applying the loss
59
+ remove_only_teacher_neg: false # (bool): remove negative similarities of the teacher
60
+ tokens_used: all # (str): In [all, masked, unmasked]
61
+ global_teacher_resize_method: bicubic # Method for resizing the outputs of the gram teacher
62
+ global_teacher_resize_antialias: false # Whether to use antialiasing when resizing the outputs of the gram teacher
63
+ loss_weight_schedule: null # (dict): If not None, use a schedule for the loss weight instead of `loss_weight`
64
+ train:
65
+ batch_size_per_gpu: 64
66
+ dataset_path: ImageNet:split=TRAIN
67
+ data_config: null
68
+ output_dir: .
69
+ saveckp_freq: 20
70
+ seed: 0
71
+ num_workers: 10
72
+ OFFICIAL_EPOCH_LENGTH: 1250
73
+ monitor_gradient_norm: false
74
+ chunk_schedule: []
75
+ use_teacher_head: true
76
+ learn_from_teacher_tokens: false
77
+ centering: "sinkhorn_knopp" # or "centering"
78
+ checkpointing: false
79
+ checkpointing_full: false # aggressive checkpointing
80
+ compile: true
81
+ cudagraphs: false
82
+ sharded_eval_checkpoint: false
83
+ cache_dataset: false
84
+ student:
85
+ arch: vit_large
86
+ patch_size: 16
87
+ drop_path_rate: 0.3
88
+ layerscale: 1.0e-05
89
+ pretrained_weights: ''
90
+ ffn_layer: "mlp"
91
+ ffn_ratio: 4.0
92
+ resume_from_teacher_chkpt: ""
93
+ qkv_bias: true
94
+ proj_bias: true
95
+ ffn_bias: true
96
+ norm_layer: "layernorm"
97
+ n_storage_tokens: 0
98
+ mask_k_bias: false
99
+ untie_cls_and_patch_norms: false # If true, use separate norms for CLS/reg and patch/mask tokens
100
+ untie_global_and_local_cls_norm: false # If true, use separate norms for local and global crop CLS token during training
101
+ in_chans: 3
102
+ pos_embed_type: rope
103
+ pos_embed_rope_base: 100.0
104
+ pos_embed_rope_min_period: null
105
+ pos_embed_rope_max_period: null
106
+ pos_embed_rope_normalize_coords: separate # min, max, separate
107
+ pos_embed_rope_shift_coords: null
108
+ pos_embed_rope_jitter_coords: null
109
+ pos_embed_rope_rescale_coords: null
110
+ pos_embed_rope_dtype: bf16
111
+ fp8_enabled: False # Convert Linear layers to operate in fp8 precision
112
+ fp8_filter: "blocks" # Regex that must appear in module path; empty means everything
113
+ teacher:
114
+ momentum_teacher: 0.992
115
+ final_momentum_teacher: 1
116
+ warmup_teacher_temp: 0.04
117
+ teacher_temp: 0.07
118
+ warmup_teacher_temp_epochs: 30
119
+ in_chans: 3
120
+ distillation: # teacher
121
+ enabled: false
122
+ full_cfg_path: ""
123
+ checkpoint_path: ""
124
+ multidistillation:
125
+ enabled: false
126
+ hrft: # non-hrft'd student
127
+ enabled: false
128
+ checkpoint_path: "" # teacher_checkpoint path
129
+ optim:
130
+ epochs: 100
131
+ optimizer: adamw
132
+ weight_decay: 0.04
133
+ weight_decay_end: 0.4
134
+ lr: 0.001
135
+ warmup_epochs: 10
136
+ min_lr: 1.0e-06
137
+ schedule_trunc_extra: 0.0 # Compute the schedule for (1 + schedule_trunc_extra) steps and truncate, .25 is a good choice
138
+ clip_grad: 3.0
139
+ freeze_last_layer_epochs: 1
140
+ scaling_rule: sqrt_wrt_1024
141
+ patch_embed_lr_mult: 0.2
142
+ dino_head_wd_multiplier: 1.0
143
+ layerwise_decay: 0.9
144
+ multi_tensor_optim: true
145
+ dump_fsdp_weights_path: ""
146
+ adamw_beta1: 0.9
147
+ adamw_beta2: 0.999
148
+ crops:
149
+ global_crops_scale:
150
+ - 0.32
151
+ - 1.0
152
+ local_crops_number: 8
153
+ local_crops_scale:
154
+ - 0.05
155
+ - 0.32
156
+ global_crops_size: 224
157
+ local_crops_size: 96
158
+ global_local_crop_pairs_ratios: 1.0
159
+ gram_teacher_crops_size: null # If not None, return crops for gram teacher
160
+ localcrops_subset_of_globalcrops: false
161
+ share_color_jitter: false
162
+ horizontal_flips: true
163
+ gram_teacher_no_distortions: false # If True, no distortions are applied to gram teacher crops
164
+ rgb_mean:
165
+ - 0.485
166
+ - 0.456
167
+ - 0.406
168
+ rgb_std:
169
+ - 0.229
170
+ - 0.224
171
+ - 0.225
172
+ evaluation:
173
+ eval_period_iterations: 12500
174
+ low_freq_every: 5
175
+ config_files: # Must be in fairvit/eval/configs
176
+ high_freq: benchmark_high_frequency.yaml # More often
177
+ low_freq: benchmark_low_frequency.yaml # Less often
178
+ checkpointing:
179
+ period: 3750
180
+ max_to_keep: 3
181
+ keep_every: 99999999999999999 # Save a checkpoint every N iterations, regardless of max_to_keep and period
182
+
183
+ # Example of constant schedules with schedules v2
184
+ # # schedules:
185
+ # # lr:
186
+ # # start: 0.0
187
+ # # peak: 1e-3
188
+ # # end: 1e-6
189
+ # # warmup_epochs: 10
190
+ # # freeze_last_layer_epochs: 1
191
+ # # weight_decay:
192
+ # # start: 0.04
193
+ # # peak: 0.04
194
+ # # end: 0.04
195
+ # # warmup_epochs: 0
196
+ # # momentum:
197
+ # # start: 0.992
198
+ # # peak: 0.992
199
+ # # end: 0.992
200
+ # # warmup_epochs: 0
201
+ # # teacher_temp:
202
+ # # start: 0.04
203
+ # # peak: 0.07
204
+ # # end: 0.07
205
+ # # warmup_epochs: 30