guyuchao committed on
Commit
ce04e0d
·
verified ·
1 Parent(s): ac6ce73

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation/__init__.py +4 -0
  2. torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py +92 -0
  3. torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/__init__.py +11 -0
  4. torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py +442 -0
  5. torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py +32 -0
  6. torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py +217 -0
  7. torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py +544 -0
  8. torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/utils/point_sample.py +86 -0
  9. torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/__init__.cpython-310.pyc +0 -0
  10. torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/backbones.cpython-310.pyc +0 -0
  11. torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/classifiers.cpython-310.pyc +0 -0
  12. torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/dinotxt.cpython-310.pyc +0 -0
  13. torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/utils.cpython-310.pyc +0 -0
  14. torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/__init__.py +7 -0
  15. torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/__pycache__/ops.cpython-310.pyc +0 -0
  16. torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/decode_heads.py +747 -0
  17. torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/encoder_decoder.py +351 -0
  18. torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/ops.py +28 -0
  19. torch_hub/facebookresearch_dinov2_main/dinov2/hub/utils.py +39 -0
  20. torch_hub/facebookresearch_dinov2_main/dinov2/hub/xray_dino/__pycache__/backbones.cpython-310.pyc +0 -0
  21. torch_hub/facebookresearch_dinov2_main/dinov2/hub/xray_dino/backbones.py +28 -0
  22. torch_hub/facebookresearch_dinov2_main/dinov2/layers/__init__.py +12 -0
  23. torch_hub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/dino_head.cpython-310.pyc +0 -0
  24. torch_hub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc +0 -0
  25. torch_hub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc +0 -0
  26. torch_hub/facebookresearch_dinov2_main/dinov2/layers/attention.py +99 -0
  27. torch_hub/facebookresearch_dinov2_main/dinov2/layers/patch_embed.py +88 -0
  28. torch_hub/facebookresearch_dinov2_main/dinov2/logging/__init__.py +102 -0
  29. torch_hub/facebookresearch_dinov2_main/dinov2/logging/helpers.py +194 -0
  30. torch_hub/facebookresearch_dinov2_main/dinov2/loss/__init__.py +8 -0
  31. torch_hub/facebookresearch_dinov2_main/dinov2/loss/ibot_patch_loss.py +151 -0
  32. torch_hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py +428 -0
  33. torch_hub/facebookresearch_dinov2_main/dinov2/run/__init__.py +4 -0
  34. torch_hub/facebookresearch_dinov2_main/dinov2/run/eval/cell_dino/knn.py +59 -0
  35. torch_hub/facebookresearch_dinov2_main/dinov2/run/eval/cell_dino/linear.py +59 -0
  36. torch_hub/facebookresearch_dinov2_main/dinov2/run/submit.py +122 -0
  37. torch_hub/facebookresearch_dinov2_main/dinov2/run/train/train.py +59 -0
  38. torch_hub/facebookresearch_dinov2_main/dinov2/thirdparty/CLIP/LICENSE +21 -0
  39. torch_hub/facebookresearch_dinov2_main/dinov2/thirdparty/CLIP/clip/simple_tokenizer.py +135 -0
  40. torch_hub/facebookresearch_dinov2_main/dinov2/utils/__init__.py +4 -0
  41. torch_hub/facebookresearch_dinov2_main/dinov2/utils/checkpoint.py +63 -0
  42. torch_hub/facebookresearch_dinov2_main/dinov2/utils/cluster.py +95 -0
  43. torch_hub/facebookresearch_dinov2_main/dinov2/utils/config.py +72 -0
  44. torch_hub/facebookresearch_dinov2_main/dinov2/utils/param_groups.py +103 -0
  45. torch_hub/facebookresearch_dinov2_main/docs/README_CHANNEL_ADAPTIVE_DINO.md +156 -0
  46. torch_hub/facebookresearch_dinov2_main/notebooks/cell_dino/inference.ipynb +179 -0
  47. torch_hub/facebookresearch_dinov2_main/notebooks/depth_estimation.ipynb +0 -0
  48. torch_hub/facebookresearch_dinov2_main/notebooks/semantic_segmentation.ipynb +0 -0
  49. torch_hub/facebookresearch_dinov2_main/scripts/cell_dino/launcher_knn_eval_on_chammi.sh +34 -0
  50. torch_hub/trusted_list +0 -0
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from abc import ABCMeta, abstractmethod
7
+
8
+ import torch
9
+
10
+ from .sampling_result import SamplingResult
11
+
12
+
13
class BaseSampler(metaclass=ABCMeta):
    """Abstract base for samplers that pick positive and negative boxes.

    Concrete samplers implement :meth:`_sample_pos` and :meth:`_sample_neg`;
    :meth:`sample` combines them into a ``SamplingResult``.
    """

    def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs):
        # The sampler serves as both its own positive and negative sampler.
        self.pos_sampler = self.neg_sampler = self
        self.num = num
        self.pos_fraction = pos_fraction
        self.neg_pos_ub = neg_pos_ub
        self.add_gt_as_proposals = add_gt_as_proposals

    @abstractmethod
    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Return the indices of the sampled positive boxes."""

    @abstractmethod
    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Return the indices of the sampled negative boxes."""

    def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs):
        """Draw positive and negative boxes from the given candidates.

        Args:
            assign_result (:obj:`AssignResult`): Assignment of boxes to ground
                truth. It is modified in place via ``add_gt_`` when ground-truth
                boxes are added as proposals.
            bboxes (Tensor): Candidate boxes; only the first four columns are
                used. A 1-D tensor is treated as a single box.
            gt_bboxes (Tensor): Ground-truth boxes.
            gt_labels (Tensor, optional): Labels of the ground-truth boxes.
                Required when ``add_gt_as_proposals`` is set and ``gt_bboxes``
                is non-empty.

        Returns:
            :obj:`SamplingResult`: The sampling result.

        Raises:
            ValueError: If ground-truth boxes have to be added as proposals but
                ``gt_labels`` is ``None``.
        """
        if bboxes.ndim < 2:
            bboxes = bboxes[None, :]
        bboxes = bboxes[:, :4]

        # Flag vector marking which entries of ``bboxes`` are ground-truth boxes.
        gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8)
        prepend_gt = self.add_gt_as_proposals and len(gt_bboxes) > 0
        if prepend_gt:
            if gt_labels is None:
                raise ValueError("gt_labels must be given when add_gt_as_proposals is True")
            bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
            assign_result.add_gt_(gt_labels)
            gt_marks = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
            gt_flags = torch.cat([gt_marks, gt_flags])

        pos_budget = int(self.num * self.pos_fraction)
        # Sampled indices occasionally contain duplicates, hence ``unique()``.
        pos_inds = self.pos_sampler._sample_pos(assign_result, pos_budget, bboxes=bboxes, **kwargs).unique()
        pos_count = pos_inds.numel()

        neg_budget = self.num - pos_count
        if self.neg_pos_ub >= 0:
            # Cap negatives at ``neg_pos_ub`` times the positives (at least one).
            neg_budget = min(neg_budget, int(self.neg_pos_ub * max(1, pos_count)))
        neg_inds = self.neg_sampler._sample_neg(assign_result, neg_budget, bboxes=bboxes, **kwargs).unique()

        return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags)
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .backbones import * # noqa: F403
7
+ from .builder import MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, build_match_cost
8
+ from .decode_heads import * # noqa: F403
9
+ from .losses import * # noqa: F403
10
+ from .plugins import * # noqa: F403
11
+ from .segmentors import * # noqa: F403
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from functools import partial
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.utils.checkpoint as cp
11
+
12
+ from ...ops.modules import MSDeformAttn
13
+ from .drop_path import DropPath
14
+
15
+
16
def get_reference_points(spatial_shapes, device):
    """Build normalized reference points at the cell centres of each grid.

    Args:
        spatial_shapes: Iterable of ``(H, W)`` grid sizes, one per level.
        device: Device on which the points are created.

    Returns:
        Tensor of shape ``(1, sum(H * W), 1, 2)`` holding ``(x, y)`` cell-centre
        coordinates divided by the grid width/height. Levels are concatenated
        in the order given, each level in row-major order.
    """
    per_level = []
    for H_, W_ in spatial_shapes:
        # "ij" is the layout the legacy default produced; naming it explicitly
        # avoids the deprecation warning newer PyTorch emits when it is omitted.
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
            torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
            indexing="ij",
        )
        ref_y = ref_y.reshape(-1)[None] / H_
        ref_x = ref_x.reshape(-1)[None] / W_
        per_level.append(torch.stack((ref_x, ref_y), -1))
    reference_points = torch.cat(per_level, 1)
    # Insert a singleton axis before the coordinate axis.
    return reference_points[:, :, None]
30
+
31
+
32
def deform_inputs(x, patch_size):
    """Prepare the two argument triples used by the deformable-attention blocks.

    Args:
        x: 4-D input tensor; only its last two dimensions (``h``, ``w``) and its
            device are used.
        patch_size: Patch size used to derive the token grid ``(h // p, w // p)``.

    Returns:
        ``(deform_inputs1, deform_inputs2)``; each is a list
        ``[reference_points, spatial_shapes, level_start_index]``. The first
        pairs token-grid reference points with the stride-8/16/32 grid shapes;
        the second pairs stride-8/16/32 reference points with the token-grid
        shape.
    """
    _, _, height, width = x.shape
    device = x.device
    pyramid_sizes = [(height // stride, width // stride) for stride in (8, 16, 32)]
    token_sizes = [(height // patch_size, width // patch_size)]

    def _shapes_and_starts(sizes):
        # Start offset of each level inside the flattened, concatenated sequence.
        shapes = torch.as_tensor(sizes, dtype=torch.long, device=device)
        starts = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
        return shapes, starts

    shapes1, starts1 = _shapes_and_starts(pyramid_sizes)
    points1 = get_reference_points(token_sizes, device)

    shapes2, starts2 = _shapes_and_starts(token_sizes)
    points2 = get_reference_points(pyramid_sizes, device)

    return [points1, shapes1, starts1], [points2, shapes2, starts2]
47
+
48
+
49
class ConvFFN(nn.Module):
    """Feed-forward block with a depthwise convolution between its two linear layers.

    Args:
        in_features: Input feature size.
        hidden_features: Hidden size; falls back to ``in_features`` when falsy.
        out_features: Output size; falls back to ``in_features`` when falsy.
        act_layer: Activation class, instantiated without arguments.
        drop: Dropout probability, applied after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        """Apply fc1 -> depthwise conv -> activation -> dropout -> fc2 -> dropout."""
        hidden = self.dwconv(self.fc1(x), H, W)
        hidden = self.drop(self.act(hidden))
        return self.drop(self.fc2(hidden))
68
+
69
+
70
class DWConv(nn.Module):
    """Depthwise 3x3 convolution applied to a token sequence made of three grids.

    The sequence is split 16:4:1 and the segments are reshaped to
    ``(2H, 2W)``, ``(H, W)`` and ``(H // 2, W // 2)`` maps. The same
    convolution is applied to each before flattening them back.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """Convolve each of the three segments of ``x`` (``B, N, C``); output has the same shape."""
        batch, num_tokens, channels = x.shape
        unit = num_tokens // 21  # the three segments hold 16, 4 and 1 units
        bounds = (0, 16 * unit, 20 * unit, num_tokens)
        grids = ((H * 2, W * 2), (H, W), (H // 2, W // 2))

        # Reshape every segment first, then convolve, as in a straight-line version.
        maps = [
            x[:, start:stop, :].transpose(1, 2).view(batch, channels, rows, cols).contiguous()
            for start, stop, (rows, cols) in zip(bounds[:-1], bounds[1:], grids)
        ]
        tokens = [self.dwconv(fmap).flatten(2).transpose(1, 2) for fmap in maps]
        return torch.cat(tokens, dim=1)
86
+
87
+
88
class Extractor(nn.Module):
    """Residual deformable cross-attention from ``query`` to ``feat``, with an optional ConvFFN.

    Args:
        dim: Feature size of query and feature tokens.
        num_heads, n_points, n_levels, deform_ratio: Forwarded to ``MSDeformAttn``.
        with_cffn: Whether to add the ConvFFN residual branch.
        cffn_ratio: ConvFFN hidden size as a fraction of ``dim``.
        drop: Dropout probability inside the ConvFFN.
        drop_path: Stochastic-depth probability on the ConvFFN branch.
        norm_layer: Factory for the normalization layers.
        with_cp: Use activation checkpointing when the query requires grad.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        n_levels=1,
        deform_ratio=1.0,
        with_cffn=True,
        cffn_ratio=0.25,
        drop=0.0,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        with_cp=False,
    ):
        super().__init__()
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(
            d_model=dim,
            n_levels=n_levels,
            n_heads=num_heads,
            n_points=n_points,
            ratio=deform_ratio,
        )
        self.with_cffn = with_cffn
        self.with_cp = with_cp
        if with_cffn:
            hidden = int(dim * cffn_ratio)
            self.ffn = ConvFFN(in_features=dim, hidden_features=hidden, drop=drop)
            self.ffn_norm = norm_layer(dim)
            if drop_path > 0.0:
                self.drop_path = DropPath(drop_path)
            else:
                self.drop_path = nn.Identity()

    def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W):
        """Return ``query`` updated by attention over ``feat`` and, if enabled, the ConvFFN."""

        def _inner_forward(query, feat):
            attended = self.attn(
                self.query_norm(query),
                reference_points,
                self.feat_norm(feat),
                spatial_shapes,
                level_start_index,
                None,
            )
            query = query + attended
            if not self.with_cffn:
                return query
            return query + self.drop_path(self.ffn(self.ffn_norm(query), H, W))

        use_checkpoint = self.with_cp and query.requires_grad
        if use_checkpoint:
            return cp.checkpoint(_inner_forward, query, feat)
        return _inner_forward(query, feat)
134
+
135
+
136
class Injector(nn.Module):
    """Adds deformable cross-attention over ``feat`` to ``query``, scaled by a learnable ``gamma``.

    Args:
        dim: Feature size of query and feature tokens.
        num_heads, n_points, n_levels, deform_ratio: Forwarded to ``MSDeformAttn``.
        norm_layer: Factory for the normalization layers.
        init_values: Initial value of every entry of ``gamma``.
        with_cp: Use activation checkpointing when the query requires grad.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        n_levels=1,
        deform_ratio=1.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_values=0.0,
        with_cp=False,
    ):
        super().__init__()
        self.with_cp = with_cp
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(
            d_model=dim,
            n_levels=n_levels,
            n_heads=num_heads,
            n_points=n_points,
            ratio=deform_ratio,
        )
        # Per-channel scale of the attention residual.
        self.gamma = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, query, reference_points, feat, spatial_shapes, level_start_index):
        """Return ``query + gamma * attn(norm(query), norm(feat))``."""

        def _inner_forward(query, feat):
            attended = self.attn(
                self.query_norm(query),
                reference_points,
                self.feat_norm(feat),
                spatial_shapes,
                level_start_index,
                None,
            )
            return query + self.gamma * attended

        use_checkpoint = self.with_cp and query.requires_grad
        if use_checkpoint:
            return cp.checkpoint(_inner_forward, query, feat)
        return _inner_forward(query, feat)
171
+
172
+
173
class InteractionBlock(nn.Module):
    """Injector -> transformer blocks -> extractor(s) interaction stage.

    Args:
        dim: Token feature size.
        num_heads, n_points, deform_ratio: Deformable-attention settings shared
            by the injector and all extractors.
        norm_layer: Factory for the normalization layers.
        drop, drop_path, with_cffn, cffn_ratio: Forwarded to the extractors.
        init_values: Initial value of the injector's ``gamma``.
        extra_extractor: When truthy, two more extractors run after the main one.
        with_cp: Enable activation checkpointing in the sub-modules.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        super().__init__()

        self.injector = Injector(
            dim=dim,
            n_levels=3,
            num_heads=num_heads,
            init_values=init_values,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )
        self.extractor = Extractor(
            dim=dim,
            n_levels=1,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cffn=with_cffn,
            cffn_ratio=cffn_ratio,
            drop=drop,
            drop_path=drop_path,
            with_cp=with_cp,
        )
        if extra_extractor:
            # Kept as nn.Sequential (only ever iterated) so parameter names stay unchanged.
            self.extra_extractors = nn.Sequential(
                *[
                    Extractor(
                        dim=dim,
                        num_heads=num_heads,
                        n_points=n_points,
                        norm_layer=norm_layer,
                        with_cffn=with_cffn,
                        cffn_ratio=cffn_ratio,
                        deform_ratio=deform_ratio,
                        drop=drop,
                        drop_path=drop_path,
                        with_cp=with_cp,
                    )
                    for _ in range(2)
                ]
            )
        else:
            self.extra_extractors = None

    def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        """Run one interaction stage.

        Args:
            x: Transformer tokens (attention query of the injector).
            c: Tokens injected into ``x`` and then refreshed from ``x``.
            blocks: Iterable of transformer blocks called as ``blk(x, H_toks, W_toks)``.
            deform_inputs1: ``[reference_points, spatial_shapes, level_start_index]``
                for the injector.
            deform_inputs2: Same triple for the extractors.
            H_c, W_c: Grid size passed to the extractors.
            H_toks, W_toks: Token grid size passed to the transformer blocks.

        Returns:
            Tuple ``(x, c)`` of updated tokens.
        """
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )
        for blk in blocks:
            x = blk(x, H_toks, W_toks)

        # The main extractor and any extra ones take identical arguments, in order.
        extractors = [self.extractor]
        if self.extra_extractors is not None:
            extractors.extend(self.extra_extractors)
        for extractor in extractors:
            c = extractor(
                query=c,
                reference_points=deform_inputs2[0],
                feat=x,
                spatial_shapes=deform_inputs2[1],
                level_start_index=deform_inputs2[2],
                H=H_c,
                W=W_c,
            )
        return x, c
266
+
267
+
268
class InteractionBlockWithCls(nn.Module):
    """Interaction stage that carries a separate class token through the transformer blocks.

    The class token is left out of the injector and extractors. It is
    concatenated in front of ``x`` only while the transformer blocks run.

    Args:
        dim: Token feature size.
        num_heads, n_points, deform_ratio: Deformable-attention settings shared
            by the injector and all extractors.
        norm_layer: Factory for the normalization layers.
        drop, drop_path, with_cffn, cffn_ratio: Forwarded to the extractors.
        init_values: Initial value of the injector's ``gamma``.
        extra_extractor: When truthy, two more extractors run after the main one.
        with_cp: Enable activation checkpointing in the sub-modules.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        super().__init__()

        self.injector = Injector(
            dim=dim,
            n_levels=3,
            num_heads=num_heads,
            init_values=init_values,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )
        self.extractor = Extractor(
            dim=dim,
            n_levels=1,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cffn=with_cffn,
            cffn_ratio=cffn_ratio,
            drop=drop,
            drop_path=drop_path,
            with_cp=with_cp,
        )
        if extra_extractor:
            # Kept as nn.Sequential (only ever iterated) so parameter names stay unchanged.
            self.extra_extractors = nn.Sequential(
                *[
                    Extractor(
                        dim=dim,
                        num_heads=num_heads,
                        n_points=n_points,
                        norm_layer=norm_layer,
                        with_cffn=with_cffn,
                        cffn_ratio=cffn_ratio,
                        deform_ratio=deform_ratio,
                        drop=drop,
                        drop_path=drop_path,
                        with_cp=with_cp,
                    )
                    for _ in range(2)
                ]
            )
        else:
            self.extra_extractors = None

    def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        """Run one interaction stage with a class token.

        Args:
            x: Transformer tokens without the class token.
            c: Tokens injected into ``x`` and then refreshed from ``x``.
            cls: Class token, concatenated in front of ``x`` along dim 1.
            blocks: Iterable of transformer blocks called as ``blk(x, H_toks, W_toks)``.
            deform_inputs1: ``[reference_points, spatial_shapes, level_start_index]``
                for the injector.
            deform_inputs2: Same triple for the extractors.
            H_c, W_c: Grid size passed to the extractors.
            H_toks, W_toks: Token grid size passed to the transformer blocks.

        Returns:
            Tuple ``(x, c, cls)`` of updated tensors.
        """
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )

        # The class token takes part in the transformer blocks only.
        x = torch.cat((cls, x), dim=1)
        for blk in blocks:
            x = blk(x, H_toks, W_toks)
        cls, x = x[:, :1], x[:, 1:]

        # The main extractor and any extra ones take identical arguments, in order.
        extractors = [self.extractor]
        if self.extra_extractors is not None:
            extractors.extend(self.extra_extractors)
        for extractor in extractors:
            c = extractor(
                query=c,
                reference_points=deform_inputs2[0],
                feat=x,
                spatial_shapes=deform_inputs2[1],
                level_start_index=deform_inputs2[2],
                H=H_c,
                W=W_c,
            )
        return x, c, cls
372
+
373
+
374
class SpatialPriorModule(nn.Module):
    """Convolutional stem producing features at strides 4, 8, 16 and 32.

    Args:
        inplanes: Channel width of the stem; later stages use 2x and 4x this width.
        embed_dim: Output channel count of the 1x1 projections.
        with_cp: Use activation checkpointing when the input requires grad.
    """

    def __init__(self, inplanes=64, embed_dim=384, with_cp=False):
        super().__init__()
        self.with_cp = with_cp

        def conv_bn_relu(in_ch, out_ch, stride):
            # 3x3 convolution (no bias) + SyncBatchNorm + in-place ReLU.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.SyncBatchNorm(out_ch),
                nn.ReLU(inplace=True),
            ]

        def project(in_ch):
            # 1x1 projection to the embedding width.
            return nn.Conv2d(in_ch, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)

        self.stem = nn.Sequential(
            *conv_bn_relu(3, inplanes, 2),
            *conv_bn_relu(inplanes, inplanes, 1),
            *conv_bn_relu(inplanes, inplanes, 1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.conv2 = nn.Sequential(*conv_bn_relu(inplanes, 2 * inplanes, 2))
        self.conv3 = nn.Sequential(*conv_bn_relu(2 * inplanes, 4 * inplanes, 2))
        self.conv4 = nn.Sequential(*conv_bn_relu(4 * inplanes, 4 * inplanes, 2))
        self.fc1 = project(inplanes)
        self.fc2 = project(2 * inplanes)
        self.fc3 = project(4 * inplanes)
        self.fc4 = project(4 * inplanes)

    def forward(self, x):
        """Return ``(c1, c2, c3, c4)``.

        ``c1`` stays a ``(B, embed_dim, H/4, W/4)`` map. ``c2``..``c4`` are
        flattened to ``(B, tokens, embed_dim)`` at strides 8, 16 and 32.
        """

        def _inner_forward(x):
            feat4 = self.stem(x)
            feat8 = self.conv2(feat4)
            feat16 = self.conv3(feat8)
            feat32 = self.conv4(feat16)

            c1 = self.fc1(feat4)
            c2 = self.fc2(feat8)
            c3 = self.fc3(feat16)
            c4 = self.fc4(feat32)

            bs, dim, _, _ = c1.shape
            # c1 keeps its spatial layout; the coarser levels become token sequences.
            c2, c3, c4 = (level.view(bs, dim, -1).transpose(1, 2) for level in (c2, c3, c4))
            return c1, c2, c3, c4

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+ from torch import nn
11
+
12
+
13
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Zero whole samples of ``x`` at random (stochastic depth).

    Returns ``x`` itself when not training or when ``drop_prob`` is zero.
    Otherwise each sample along dim 0 is kept with probability
    ``1 - drop_prob``, and kept samples are rescaled by ``1 / (1 - drop_prob)``.
    """
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions.
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
22
+
23
+
24
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth)."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` using the stored probability and the module's training flag."""
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from mmseg.models.builder import BACKBONES
12
+ from torch.nn.init import normal_
13
+
14
+ from ...ops.modules import MSDeformAttn
15
+ from .adapter_modules import InteractionBlock, InteractionBlockWithCls, SpatialPriorModule, deform_inputs
16
+ from .vit import TIMMVisionTransformer
17
+
18
+
19
@BACKBONES.register_module()
class ViTAdapter(TIMMVisionTransformer):
    """ViT backbone with a convolutional spatial prior and injector/extractor stages.

    Args:
        pretrain_size: Side length of the pretraining image; used to reshape
            the positional embedding before interpolation.
        num_heads: Attention heads of the ViT (forwarded to the base class).
        conv_inplane: Channel width of the spatial prior module.
        n_points, deform_num_heads, deform_ratio: Deformable-attention settings
            of the interaction blocks.
        init_values: Initial value of the injectors' ``gamma``.
        interaction_indexes: One ``[first, ..., last]`` index list per
            interaction stage, selecting ``self.blocks[first : last + 1]``.
            Required, despite the ``None`` default.
        with_cffn, cffn_ratio: ConvFFN settings of the extractors.
        add_vit_feature: Add the resized ViT outputs of each stage to the
            pyramid features; this requires exactly four interaction stages.
        pretrained: Forwarded to the base class.
        use_extra_extractor: Give the last interaction stage two extra extractors.
        freeze_vit: Freeze every parameter created by the base class.
        use_cls: Keep and propagate the class token.
        with_cp: Enable activation checkpointing.
    """

    def __init__(
        self,
        pretrain_size=224,
        num_heads=12,
        conv_inplane=64,
        n_points=4,
        deform_num_heads=6,
        init_values=0.0,
        interaction_indexes=None,
        with_cffn=True,
        cffn_ratio=0.25,
        deform_ratio=1.0,
        add_vit_feature=True,
        pretrained=None,
        use_extra_extractor=True,
        freeze_vit=False,
        use_cls=True,
        with_cp=False,
        *args,
        **kwargs
    ):
        # Fail early and clearly: everything below needs len(interaction_indexes).
        if interaction_indexes is None:
            raise ValueError("interaction_indexes must be provided (one index list per interaction stage)")

        super().__init__(num_heads=num_heads, pretrained=pretrained, with_cp=with_cp, *args, **kwargs)
        if freeze_vit:
            # Only the base-class parameters exist at this point; adapter
            # modules created below stay trainable.
            for param in self.parameters():
                param.requires_grad = False

        self.use_cls = use_cls
        if not self.use_cls:
            self.cls_token = None
        self.num_block = len(self.blocks)
        self.pretrain_size = (pretrain_size, pretrain_size)
        self.interaction_indexes = interaction_indexes
        self.add_vit_feature = add_vit_feature
        embed_dim = self.embed_dim

        block_fn = InteractionBlockWithCls if use_cls else InteractionBlock
        num_stages = len(interaction_indexes)

        self.level_embed = nn.Parameter(torch.zeros(3, embed_dim))
        self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False)
        self.interactions = nn.Sequential(
            *[
                block_fn(
                    dim=embed_dim,
                    num_heads=deform_num_heads,
                    n_points=n_points,
                    init_values=init_values,
                    drop_path=self.drop_path_rate,
                    norm_layer=self.norm_layer,
                    with_cffn=with_cffn,
                    cffn_ratio=cffn_ratio,
                    deform_ratio=deform_ratio,
                    # Only the last stage may receive the extra extractors.
                    extra_extractor=(i == num_stages - 1 and use_extra_extractor),
                    with_cp=with_cp,
                )
                for i in range(num_stages)
            ]
        )
        self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)
        self.norm1 = nn.SyncBatchNorm(embed_dim)
        self.norm2 = nn.SyncBatchNorm(embed_dim)
        self.norm3 = nn.SyncBatchNorm(embed_dim)
        self.norm4 = nn.SyncBatchNorm(embed_dim)

        self.up.apply(self._init_weights)
        self.spm.apply(self._init_weights)
        self.interactions.apply(self._init_weights)
        self.apply(self._init_deform_weights)
        normal_(self.level_embed)

    def _init_weights(self, m):
        """Initialise Linear, LayerNorm/BatchNorm2d and (transposed) Conv2d modules."""
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            # Normal init with std sqrt(2 / fan_out), fan_out computed per group.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def _get_pos_embed(self, pos_embed, H, W):
        """Bicubically resize the patch positional embedding to an ``H x W`` token grid."""
        grid_h = self.pretrain_size[0] // self.patch_size
        grid_w = self.pretrain_size[1] // self.patch_size
        pos_embed = pos_embed.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2)
        pos_embed = (
            F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
            .reshape(1, -1, H * W)
            .permute(0, 2, 1)
        )
        return pos_embed

    def _init_deform_weights(self, m):
        """Reset the parameters of every ``MSDeformAttn`` module."""
        if isinstance(m, MSDeformAttn):
            m._reset_parameters()

    def _add_level_embed(self, c2, c3, c4):
        """Add the learnable per-level embedding to each of the three token sets."""
        c2 = c2 + self.level_embed[0]
        c3 = c3 + self.level_embed[1]
        c4 = c4 + self.level_embed[2]
        return c2, c3, c4

    def forward(self, x):
        """Return the four normalised pyramid features ``[f1, f2, f3, f4]`` for image batch ``x``."""
        deform_inputs1, deform_inputs2 = deform_inputs(x, self.patch_size)

        # Spatial prior: c1 stays a feature map, c2..c4 are token sequences.
        c1, c2, c3, c4 = self.spm(x)
        c2, c3, c4 = self._add_level_embed(c2, c3, c4)
        len2, len3 = c2.size(1), c3.size(1)
        c = torch.cat([c2, c3, c4], dim=1)

        # Patch embedding and positional embedding.
        H_c, W_c = x.shape[2] // 16, x.shape[3] // 16
        x, H_toks, W_toks = self.patch_embed(x)
        bs, _, dim = x.shape
        pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H_toks, W_toks)
        if self.use_cls:
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_token, x), dim=1)
            pos_embed = torch.cat((self.pos_embed[:, :1], pos_embed), dim=1)
        x = self.pos_drop(x + pos_embed)
        # For CLIP
        x = self.norm_pre(x)

        # Interaction stages.
        if self.use_cls:
            cls, x = x[:, :1], x[:, 1:]
        outs = []
        for i, layer in enumerate(self.interactions):
            indexes = self.interaction_indexes[i]
            stage_blocks = self.blocks[indexes[0] : indexes[-1] + 1]
            if self.use_cls:
                x, c, cls = layer(
                    x, c, cls, stage_blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks
                )
            else:
                x, c = layer(x, c, stage_blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks)
            outs.append(x.transpose(1, 2).view(bs, dim, H_toks, W_toks).contiguous())

        # Split the concatenated tokens back into levels and restore 2-D maps.
        c2 = c[:, 0:len2, :]
        c3 = c[:, len2 : len2 + len3, :]
        c4 = c[:, len2 + len3 :, :]

        c2 = c2.transpose(1, 2).view(bs, dim, H_c * 2, W_c * 2).contiguous()
        c3 = c3.transpose(1, 2).view(bs, dim, H_c, W_c).contiguous()
        c4 = c4.transpose(1, 2).view(bs, dim, H_c // 2, W_c // 2).contiguous()
        c1 = self.up(c2) + c1

        if self.add_vit_feature:
            # Unpacking requires exactly four interaction stages.
            x1, x2, x3, x4 = outs

            x1 = F.interpolate(x1, size=(4 * H_c, 4 * W_c), mode="bilinear", align_corners=False)
            x2 = F.interpolate(x2, size=(2 * H_c, 2 * W_c), mode="bilinear", align_corners=False)
            x3 = F.interpolate(x3, size=(1 * H_c, 1 * W_c), mode="bilinear", align_corners=False)
            x4 = F.interpolate(x4, size=(H_c // 2, W_c // 2), mode="bilinear", align_corners=False)
            c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4

        # Final norm.
        f1 = self.norm1(c1)
        f2 = self.norm2(c2)
        f3 = self.norm3(c3)
        f4 = self.norm4(c4)
        return [f1, f2, f3, f4]
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import copy
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
12
+ from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence
13
+ from mmcv.ops import point_sample
14
+ from mmcv.runner import ModuleList, force_fp32
15
+ from mmseg.models.builder import HEADS, build_loss
16
+ from mmseg.models.decode_heads.decode_head import BaseDecodeHead
17
+
18
+ from ...core import build_sampler, multi_apply, reduce_mean
19
+ from ..builder import build_assigner
20
+ from ..utils import get_uncertain_point_coords_with_randomness
21
+
22
+
23
@HEADS.register_module()
class Mask2FormerHead(BaseDecodeHead):
    """Implements the Mask2Former head.

    See `Masked-attention Mask Transformer for Universal Image
    Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in the input feature map.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for output.
        num_things_classes (int): Number of things.
        num_stuff_classes (int): Number of stuff.
        num_queries (int): Number of query in Transformer decoder.
        num_transformer_feat_level (int): Number of multi-scale feature
            levels fed to the transformer decoder. Must equal the
            ``num_levels`` configured in the pixel decoder's encoder
            attention. Defaults to 3.
        pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
            decoder. Defaults to None.
        enforce_decoder_input_project (bool, optional): Whether to add
            a layer to change the embed_dim of transformer encoder in
            pixel decoder to the embed_dim of transformer decoder.
            Defaults to False.
        transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
            transformer decoder. Defaults to None.
        positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
            transformer decoder position encoding. Defaults to None.
        loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
            loss. Defaults to None.
        loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
            Defaults to None.
        loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
            Defaults to None.
        train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of
            Mask2Former head.
        test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of
            Mask2Former head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
        self,
        in_channels,
        feat_channels,
        out_channels,
        num_things_classes=80,
        num_stuff_classes=53,
        num_queries=100,
        num_transformer_feat_level=3,
        pixel_decoder=None,
        enforce_decoder_input_project=False,
        transformer_decoder=None,
        positional_encoding=None,
        loss_cls=None,
        loss_mask=None,
        loss_dice=None,
        train_cfg=None,
        test_cfg=None,
        init_cfg=None,
        **kwargs,
    ):
        super().__init__(
            in_channels=in_channels,
            channels=feat_channels,
            num_classes=(num_things_classes + num_stuff_classes),
            init_cfg=init_cfg,
            input_transform="multiple_select",
            **kwargs,
        )
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries
        self.num_transformer_feat_level = num_transformer_feat_level
        self.num_heads = transformer_decoder.transformerlayers.attn_cfgs.num_heads
        self.num_transformer_decoder_layers = transformer_decoder.num_layers
        assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level
        # Work on a copy so the caller's config is not mutated by ``update``.
        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(in_channels=in_channels, feat_channels=feat_channels, out_channels=out_channels)
        self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1]
        self.transformer_decoder = build_transformer_layer_sequence(transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims

        self.decoder_input_projs = ModuleList()
        # from low resolution to high resolution
        for _ in range(num_transformer_feat_level):
            if self.decoder_embed_dims != feat_channels or enforce_decoder_input_project:
                self.decoder_input_projs.append(Conv2d(feat_channels, self.decoder_embed_dims, kernel_size=1))
            else:
                self.decoder_input_projs.append(nn.Identity())
        self.decoder_positional_encoding = build_positional_encoding(positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, feat_channels)
        self.query_feat = nn.Embedding(self.num_queries, feat_channels)
        # from low resolution to high resolution
        self.level_embed = nn.Embedding(self.num_transformer_feat_level, feat_channels)

        # One extra class output for the "no object" / background label.
        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels),
            nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels),
            nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels),
        )
        self.conv_seg = None  # fix a bug here (conv_seg is not used)

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            self.sampler = build_sampler(self.train_cfg.sampler, context=self)
            self.num_points = self.train_cfg.get("num_points", 12544)
            self.oversample_ratio = self.train_cfg.get("oversample_ratio", 3.0)
            self.importance_sample_ratio = self.train_cfg.get("importance_sample_ratio", 0.75)

        self.class_weight = loss_cls.class_weight
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)

    def init_weights(self):
        """Initialize the input projections, pixel decoder and transformer decoder."""
        for m in self.decoder_input_projs:
            if isinstance(m, Conv2d):
                caffe2_xavier_init(m, bias=0)

        self.pixel_decoder.init_weights()

        # Only re-initialize weight matrices; 1-D parameters are left untouched.
        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas):
        """Compute classification and mask targets for all images for a decoder
        layer.

        Args:
            cls_scores_list (list[Tensor]): Mask score logits from a single
                decoder layer for all images. Each with shape [num_queries,
                cls_out_channels].
            mask_preds_list (list[Tensor]): Mask logits from a single decoder
                layer for all images. Each with shape [num_queries, h, w].
            gt_labels_list (list[Tensor]): Ground truth class indices for all
                images. Each with shape (n, ), n is the sum of number of stuff
                type and number of instance in a image.
            gt_masks_list (list[Tensor]): Ground truth mask for each image,
                each with shape (n, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[list[Tensor]]: a tuple containing the following targets.

                - labels_list (list[Tensor]): Labels of all images.
                    Each with shape [num_queries, ].
                - label_weights_list (list[Tensor]): Label weights of all
                    images. Each with shape [num_queries, ].
                - mask_targets_list (list[Tensor]): Mask targets of all images.
                    Each with shape [num_queries, h, w].
                - mask_weights_list (list[Tensor]): Mask weights of all images.
                    Each with shape [num_queries, ].
                - num_total_pos (int): Number of positive samples in all
                    images.
                - num_total_neg (int): Number of negative samples in all
                    images.
        """
        (
            labels_list,
            label_weights_list,
            mask_targets_list,
            mask_weights_list,
            pos_inds_list,
            neg_inds_list,
        ) = multi_apply(
            self._get_target_single, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas
        )

        num_total_pos = sum(inds.numel() for inds in pos_inds_list)
        num_total_neg = sum(inds.numel() for inds in neg_inds_list)
        return (labels_list, label_weights_list, mask_targets_list, mask_weights_list, num_total_pos, num_total_neg)

    def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, img_metas):
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_labels (Tensor): Ground truth class indices for one image with
                shape (num_gts, ).
            gt_masks (Tensor): Ground truth mask for each image, each with
                shape (num_gts, h, w).
            img_metas (dict): Image information.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image.
                    shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image.
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image.
                    shape (num_queries, h, w).
                - mask_weights (Tensor): Mask weights of each image.
                    shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each
                    image.
                - neg_inds (Tensor): Sampled negative indices for each
                    image.
        """
        # sample points
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]

        # The same random points are used for predictions and ground truth so
        # that the assigner compares them at identical locations.
        point_coords = torch.rand((1, self.num_points, 2), device=cls_score.device)
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, 1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, 1)).squeeze(1)

        # assign and sample
        assign_result = self.assigner.assign(cls_score, mask_points_pred, gt_labels, gt_points_masks, img_metas)
        sampling_result = self.sampler.sample(assign_result, mask_pred, gt_masks)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target: unmatched queries get index ``num_classes`` (background)
        labels = gt_labels.new_full((self.num_queries,), self.num_classes, dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones((self.num_queries,))

        # mask target
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries,))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds)

    def loss_single(self, cls_scores, mask_preds, gt_labels_list, gt_masks_list, img_metas):
        """Loss function for outputs from a single decoder layer.

        Args:
            cls_scores (Tensor): Mask score logits from a single decoder layer
                for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            mask_preds (Tensor): Mask logits for a pixel decoder for all
                images. Shape (batch_size, num_queries, h, w).
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image, each with shape (num_gts, ).
            gt_masks_list (list[Tensor]): Ground truth mask for each image,
                each with shape (num_gts, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[Tensor]: Loss components (loss_cls, loss_mask, loss_dice)
                for outputs from a single decoder layer.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
        (
            labels_list,
            label_weights_list,
            mask_targets_list,
            mask_weights_list,
            num_total_pos,
            _,  # number of negatives is not needed for the loss
        ) = self.get_targets(cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas)
        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (batch_size, num_queries)
        label_weights = torch.stack(label_weights_list, dim=0)
        # shape (num_total_gts, h, w)
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # classification loss
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        class_weight = cls_scores.new_tensor(self.class_weight)
        loss_cls = self.loss_cls(cls_scores, labels, label_weights, avg_factor=class_weight[labels].sum())

        num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones
        # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
        mask_preds = mask_preds[mask_weights > 0]

        if mask_targets.shape[0] == 0:
            # zero match: sum of an empty selection is 0 but keeps the graph connected
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice

        with torch.no_grad():
            points_coords = get_uncertain_point_coords_with_randomness(
                mask_preds.unsqueeze(1), None, self.num_points, self.oversample_ratio, self.importance_sample_ratio
            )
            # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
            mask_point_targets = point_sample(mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
        # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
        mask_point_preds = point_sample(mask_preds.unsqueeze(1), points_coords).squeeze(1)

        # dice loss
        loss_dice = self.loss_dice(mask_point_preds, mask_point_targets, avg_factor=num_total_masks)

        # mask loss
        # shape (num_total_gts, num_points) -> (num_total_gts * num_points, 1)
        mask_point_preds = mask_point_preds.reshape(-1, 1)
        # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
        mask_point_targets = mask_point_targets.reshape(-1)
        loss_mask = self.loss_mask(mask_point_preds, mask_point_targets, avg_factor=num_total_masks * self.num_points)

        return loss_cls, loss_mask, loss_dice

    @force_fp32(apply_to=("all_cls_scores", "all_mask_preds"))
    def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, gt_masks_list, img_metas):
        """Loss function.

        Args:
            all_cls_scores (Tensor): Classification scores for all decoder
                layers with shape [num_decoder, batch_size, num_queries,
                cls_out_channels].
            all_mask_preds (Tensor): Mask scores for all decoder layers with
                shape [num_decoder, batch_size, num_queries, h, w].
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (n, ). n is the sum of number of stuff type
                and number of instance in a image.
            gt_masks_list (list[Tensor]): Ground truth mask for each image with
                shape (n, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            dict[str, Tensor]: A dictionary of loss components. The last
                decoder layer is stored under ``loss_cls`` / ``loss_mask`` /
                ``loss_dice``; earlier layers under ``d{i}.loss_*``.
        """
        num_dec_layers = len(all_cls_scores)
        # The same ground truth is reused for every decoder layer.
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)]
        img_metas_list = [img_metas for _ in range(num_dec_layers)]
        losses_cls, losses_mask, losses_dice = multi_apply(
            self.loss_single, all_cls_scores, all_mask_preds, all_gt_labels_list, all_gt_masks_list, img_metas_list
        )

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict["loss_cls"] = losses_cls[-1]
        loss_dict["loss_mask"] = losses_mask[-1]
        loss_dict["loss_dice"] = losses_dice[-1]
        # loss from other decoder layers
        aux_losses = zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1])
        for layer_idx, (loss_cls_i, loss_mask_i, loss_dice_i) in enumerate(aux_losses):
            loss_dict[f"d{layer_idx}.loss_cls"] = loss_cls_i
            loss_dict[f"d{layer_idx}.loss_mask"] = loss_mask_i
            loss_dict[f"d{layer_idx}.loss_dice"] = loss_dice_i
        return loss_dict

    def forward_head(self, decoder_out, mask_feature, attn_mask_target_size):
        """Forward for head part which is called after every decoder layer.

        Args:
            decoder_out (Tensor): in shape (num_queries, batch_size, c).
            mask_feature (Tensor): in shape (batch_size, c, h, w).
            attn_mask_target_size (tuple[int, int]): target attention
                mask size.

        Returns:
            tuple: A tuple contain three elements.

                - cls_pred (Tensor): Classification scores in shape
                    (batch_size, num_queries, cls_out_channels).
                    Note `cls_out_channels` should include background.
                - mask_pred (Tensor): Mask scores in shape
                    (batch_size, num_queries, h, w).
                - attn_mask (Tensor): Boolean attention mask in shape
                    (batch_size * num_heads, num_queries, h' * w'), where
                    (h', w') is ``attn_mask_target_size``. True marks
                    positions whose mask probability is below 0.5.
        """
        decoder_out = self.transformer_decoder.post_norm(decoder_out)
        # shape (num_queries, batch_size, c) -> (batch_size, num_queries, c)
        decoder_out = decoder_out.transpose(0, 1)
        # shape (batch_size, num_queries, cls_out_channels)
        cls_pred = self.cls_embed(decoder_out)
        # shape (batch_size, num_queries, c)
        mask_embed = self.mask_embed(decoder_out)
        # shape (batch_size, num_queries, h, w)
        mask_pred = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature)
        attn_mask = F.interpolate(mask_pred, attn_mask_target_size, mode="bilinear", align_corners=False)
        # shape (batch_size, num_queries, h', w') ->
        #   (batch_size * num_heads, num_queries, h' * w')
        attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat((1, self.num_heads, 1, 1)).flatten(0, 1)
        attn_mask = attn_mask.sigmoid() < 0.5
        attn_mask = attn_mask.detach()

        return cls_pred, mask_pred, attn_mask

    def forward(self, feats, img_metas):
        """Forward function.

        Args:
            feats (list[Tensor]): Multi scale Features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple: A tuple contains two elements.

                - cls_pred_list (list[Tensor]): Classification logits
                    for each decoder layer. Each is a 3D-tensor with shape
                    (batch_size, num_queries, cls_out_channels).
                    Note `cls_out_channels` should include background.
                - mask_pred_list (list[Tensor]): Mask logits for each
                    decoder layer. Each with shape (batch_size, num_queries,
                    h, w).
        """
        batch_size = len(img_metas)
        mask_features, multi_scale_memorys = self.pixel_decoder(feats)
        # multi_scale_memorys (from low resolution to high resolution)
        decoder_inputs = []
        decoder_positional_encodings = []
        for i in range(self.num_transformer_feat_level):
            decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
            # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
            decoder_input = decoder_input.flatten(2).permute(2, 0, 1)
            level_embed = self.level_embed.weight[i].view(1, 1, -1)
            decoder_input = decoder_input + level_embed
            # all-False mask of shape (batch_size, h, w), used only to build
            # the positional encoding for this level
            mask = decoder_input.new_zeros((batch_size,) + multi_scale_memorys[i].shape[-2:], dtype=torch.bool)
            decoder_positional_encoding = self.decoder_positional_encoding(mask)
            decoder_positional_encoding = decoder_positional_encoding.flatten(2).permute(2, 0, 1)
            decoder_inputs.append(decoder_input)
            decoder_positional_encodings.append(decoder_positional_encoding)
        # shape (num_queries, c) -> (num_queries, batch_size, c)
        query_feat = self.query_feat.weight.unsqueeze(1).repeat((1, batch_size, 1))
        query_embed = self.query_embed.weight.unsqueeze(1).repeat((1, batch_size, 1))

        cls_pred_list = []
        mask_pred_list = []
        # Predictions from the learned queries alone, before any decoder layer.
        cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
        cls_pred_list.append(cls_pred)
        mask_pred_list.append(mask_pred)

        for i in range(self.num_transformer_decoder_layers):
            # Decoder layers cycle through the feature levels round-robin.
            level_idx = i % self.num_transformer_feat_level
            # if a mask is all True(all background), then set it all False.
            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # cross_attn + self_attn
            layer = self.transformer_decoder.layers[i]
            attn_masks = [attn_mask, None]
            query_feat = layer(
                query=query_feat,
                key=decoder_inputs[level_idx],
                value=decoder_inputs[level_idx],
                query_pos=query_embed,
                key_pos=decoder_positional_encodings[level_idx],
                attn_masks=attn_masks,
                query_key_padding_mask=None,
                # here we do not apply masking on padded region
                key_padding_mask=None,
            )
            # The attention mask is sized for the level the NEXT layer attends to.
            cls_pred, mask_pred, attn_mask = self.forward_head(
                query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:]
            )

            cls_pred_list.append(cls_pred)
            mask_pred_list.append(mask_pred)

        return cls_pred_list, mask_pred_list

    def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, gt_masks):
        """Forward function for training mode.

        Args:
            x (list[Tensor]): Multi-level features from the upstream network,
                each is a 4D-tensor.
            img_metas (list[Dict]): List of image information.
            gt_semantic_seg (list[tensor]): Each element is the ground truth
                of semantic segmentation with the shape (N, H, W). Accepted
                for interface compatibility; not used by this head.
            gt_labels (list[Tensor]): Each element is ground truth labels of
                each box, shape (num_gts,).
            gt_masks (list[BitmapMasks]): Each element is masks of instances
                of a image, shape (num_gts, h, w).

        Returns:
            losses (dict[str, Tensor]): a dictionary of loss components
        """

        # forward
        all_cls_scores, all_mask_preds = self(x, img_metas)

        # loss
        losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, img_metas)

        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Test segment without test-time augmentation.

        Only the output of last decoder layers was used.

        Args:
            inputs (list[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.
            test_cfg (dict): Testing config. Not used by this method.

        Returns:
            seg_mask (Tensor): Per-class semantic segmentation scores in
                shape (batch_size, num_classes, h, w), obtained by weighting
                each query's mask probability with its class probability.
        """
        all_cls_scores, all_mask_preds = self(inputs, img_metas)
        cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1]

        # semantic inference: drop the trailing background class, then
        # aggregate the per-query masks into per-class maps.
        cls_score = F.softmax(cls_score, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        seg_mask = torch.einsum("bqc,bqhw->bchw", cls_score, mask_pred)
        return seg_mask
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/utils/point_sample.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from mmcv.ops import point_sample
8
+
9
+
10
def get_uncertainty(mask_pred, labels):
    """Score how uncertain each mask logit is.

    Uncertainty is the negated L1 distance between the logit and 0.0, so a
    logit of exactly 0 gets the highest score (0) and confident logits get
    increasingly negative scores.

    Args:
        mask_pred (Tensor): mask prediction logits, shape (num_rois,
            num_classes, mask_height, mask_width).
        labels (list[Tensor]): Either predicted or ground truth label for
            each predicted mask, of length num_rois. Only read when
            ``mask_pred`` has more than one class channel.

    Returns:
        scores (Tensor): Uncertainty scores with the most uncertain
            locations having the highest uncertainty score,
            shape (num_rois, 1, mask_height, mask_width)
    """
    # Class-agnostic prediction: the single channel is the logit of interest.
    if mask_pred.shape[1] == 1:
        return mask_pred.clone().abs().neg()

    # Class-specific prediction: pick, for every RoI, the channel of its label.
    roi_index = torch.arange(mask_pred.shape[0], device=mask_pred.device)
    selected_logits = mask_pred[roi_index, labels].unsqueeze(1)
    return selected_logits.abs().neg()
34
+
35
+
36
def get_uncertain_point_coords_with_randomness(
    mask_pred, labels, num_points, oversample_ratio, importance_sample_ratio
):
    """Pick ``num_points`` sample locations, biased towards uncertain ones.

    Candidate points are drawn uniformly in [0, 1] x [0, 1]; the most
    uncertain candidates (as scored by ``get_uncertainty()`` on the logits
    sampled at those points) are kept, and the remainder of the budget is
    filled with fresh uniformly random points.

    Args:
        mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
            mask_height, mask_width) for class-specific or class-agnostic
            prediction.
        labels (list): The ground truth class for each instance.
        num_points (int): The number of points to sample.
        oversample_ratio (int | float): Oversampling parameter; must be >= 1.
        importance_sample_ratio (float): Ratio of points that are sampled
            via importance sampling; must lie in [0, 1].

    Returns:
        point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
            that contains the coordinates of the sampled points.
    """
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1

    device = mask_pred.device
    num_rois = mask_pred.shape[0]
    num_candidates = int(num_points * oversample_ratio)
    num_uncertain = int(importance_sample_ratio * num_points)
    num_random = num_points - num_uncertain

    candidates = torch.rand(num_rois, num_candidates, 2, device=device)
    candidate_logits = point_sample(mask_pred, candidates)
    # Uncertainty must be computed on the logits sampled at the points, not
    # sampled from a precomputed uncertainty map: with -abs(logits) as the
    # score, a point halfway between coarse logits -1 and 1 has logit 0 and
    # thus uncertainty 0, whereas interpolating the two coarse uncertainties
    # (-1 and -1) would wrongly give -1.
    uncertainties = get_uncertainty(candidate_logits, labels)

    _, top_idx = torch.topk(uncertainties[:, 0, :], k=num_uncertain, dim=1)
    # Turn per-RoI candidate indices into indices of the flattened
    # (num_rois * num_candidates) candidate list.
    roi_offsets = num_candidates * torch.arange(num_rois, dtype=torch.long, device=device)
    flat_idx = (top_idx + roi_offsets[:, None]).view(-1)
    point_coords = candidates.view(-1, 2)[flat_idx, :].view(num_rois, num_uncertain, 2)

    if num_random > 0:
        extra_coords = torch.rand(num_rois, num_random, 2, device=device)
        point_coords = torch.cat((point_coords, extra_coords), dim=1)
    return point_coords
torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (167 Bytes). View file
 
torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/backbones.cpython-310.pyc ADDED
Binary file (4.6 kB). View file
 
torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/classifiers.cpython-310.pyc ADDED
Binary file (6.32 kB). View file
 
torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/dinotxt.cpython-310.pyc ADDED
Binary file (2.78 kB). View file
 
torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.78 kB). View file
 
torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .decode_heads import BNHead, DPTHead
7
+ from .encoder_decoder import DepthEncoderDecoder
torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/__pycache__/ops.cpython-310.pyc ADDED
Binary file (1.07 kB). View file
 
torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/decode_heads.py ADDED
@@ -0,0 +1,747 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import copy
7
+ from functools import partial
8
+ import math
9
+ import warnings
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from .ops import resize
15
+
16
+
17
+ # XXX: (Untested) replacement for mmcv.imdenormalize()
18
+ def _imdenormalize(img, mean, std, to_bgr=True):
19
+ import numpy as np
20
+
21
+ mean = mean.reshape(1, -1).astype(np.float64)
22
+ std = std.reshape(1, -1).astype(np.float64)
23
+ img = (img * std) + mean
24
+ if to_bgr:
25
+ img = img[::-1]
26
+ return img
27
+
28
+
29
class DepthBaseDecodeHead(nn.Module):
    """Base class for depth decode heads.

    Args:
        in_channels (List): Input channels.
        channels (int): Channels after modules, before conv_depth.
        conv_layer (nn.Module): Conv layers. Default: None.
        act_layer (nn.Module): Activation layers. Default: nn.ReLU.
        loss_decode (dict): Config of decode loss.
            Default: ().
        sampler (dict|None): The config of depth map sampler.
            Default: None. Accepted for interface compatibility; not used here.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        min_depth (float): Min depth in dataset setting.
            Default: 1e-3.
        max_depth (float): Max depth in dataset setting.
            Default: None.
        norm_layer (dict|None): Norm layers.
            Default: None.
        classify (bool): Whether predict depth in a cls.-reg. manner.
            Default: False.
        n_bins (int): The number of bins used in cls. step.
            Default: 256.
        bins_strategy (str): The discrete strategy used in cls. step.
            Default: 'UD'.
        norm_strategy (str): The norm strategy on cls. probability
            distribution. Default: 'linear'
        scale_up (bool): Whether predict depth in a scale-up manner.
            Default: False.
    """

    def __init__(
        self,
        in_channels,
        conv_layer=None,
        act_layer=nn.ReLU,
        channels=96,
        loss_decode=(),
        sampler=None,
        align_corners=False,
        min_depth=1e-3,
        max_depth=None,
        norm_layer=None,
        classify=False,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        scale_up=False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.channels = channels
        self.conv_layer = conv_layer
        # Alias kept for code that relied on the historical misspelling.
        self.conf_layer = conv_layer
        self.act_layer = act_layer
        self.loss_decode = loss_decode
        self.align_corners = align_corners
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.norm_layer = norm_layer
        self.classify = classify
        self.n_bins = n_bins
        self.scale_up = scale_up

        if self.classify:
            assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
            assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"

            self.bins_strategy = bins_strategy
            self.norm_strategy = norm_strategy
            self.softmax = nn.Softmax(dim=1)
            self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
        else:
            self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, img_metas):
        """Placeholder of forward function; subclasses provide the real one."""
        pass

    def forward_train(self, img, inputs, img_metas, depth_gt):
        """Forward function for training.

        Args:
            img (Tensor): Input images; the first one is used for logging.
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dicts; the first one
                must provide 'img_norm_cfg' for image logging.
            depth_gt (Tensor): GT depth

        Returns:
            dict[str, Tensor]: a dictionary of loss components, extended with
                the logging images produced by ``log_images``.
        """
        depth_pred = self.forward(inputs, img_metas)
        losses = self.losses(depth_pred, depth_gt)

        log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
        losses.update(**log_imgs)

        return losses

    def forward_test(self, inputs, img_metas):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dicts.

        Returns:
            Tensor: Output depth map.
        """
        return self.forward(inputs, img_metas)

    def depth_pred(self, feat):
        """Predict a depth value for each pixel of ``feat``.

        In classification mode the per-bin scores are normalized to a
        distribution over ``n_bins`` depth values and the expected depth is
        returned; otherwise a single channel is regressed directly.
        """
        if self.classify:
            logit = self.conv_depth(feat)

            if self.bins_strategy == "UD":
                bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
            elif self.bins_strategy == "SID":
                # torch.logspace takes base-10 exponents, so pass the log of the
                # depth limits to get bins spanning [min_depth, max_depth].
                bins = torch.logspace(
                    math.log10(self.min_depth), math.log10(self.max_depth), self.n_bins, device=feat.device
                )

            # following Adabins, default linear
            if self.norm_strategy == "linear":
                logit = torch.relu(logit)
                eps = 0.1  # keeps every bin strictly positive before normalizing
                logit = logit + eps
                logit = logit / logit.sum(dim=1, keepdim=True)
            elif self.norm_strategy == "softmax":
                logit = torch.softmax(logit, dim=1)
            elif self.norm_strategy == "sigmoid":
                logit = torch.sigmoid(logit)
                logit = logit / logit.sum(dim=1, keepdim=True)

            # Expected depth: probability-weighted sum of bin values.
            output = torch.einsum("ikmn,k->imn", logit, bins).unsqueeze(dim=1)

        else:
            if self.scale_up:
                output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
            else:
                output = self.relu(self.conv_depth(feat)) + self.min_depth
        return output

    def losses(self, depth_pred, depth_gt):
        """Compute depth loss.

        The prediction is resized to the spatial size of ``depth_gt`` and each
        configured loss is accumulated under its ``loss_name``.
        """
        loss = dict()
        depth_pred = resize(
            input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
        )
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
            else:
                loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
        return loss

    def log_images(self, img_path, depth_pred, depth_gt, img_meta):
        """Build the images logged during training.

        Args:
            img_path (Tensor): Normalized input image, channel-first.
            depth_pred (Tensor): Predicted depth for that image.
            depth_gt (Tensor): Ground-truth depth for that image.
            img_meta (dict): Must contain 'img_norm_cfg' with 'mean', 'std'
                and 'to_rgb'.

        Returns:
            dict: 'img_rgb' (uint8 array, channel-first), and 'img_depth_pred'
                / 'img_depth_gt' (CPU tensors scaled by their own maximum).
        """
        import numpy as np

        show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
        show_img = show_img.numpy().astype(np.float32)
        show_img = _imdenormalize(
            show_img,
            img_meta["img_norm_cfg"]["mean"],
            img_meta["img_norm_cfg"]["std"],
            img_meta["img_norm_cfg"]["to_rgb"],
        )
        show_img = np.clip(show_img, 0, 255)
        show_img = show_img.astype(np.uint8)
        show_img = show_img[:, :, ::-1]
        # Two transposes take (H, W, C) back to (C, H, W).
        show_img = show_img.transpose(0, 2, 1)
        show_img = show_img.transpose(1, 0, 2)

        depth_pred = depth_pred / torch.max(depth_pred)
        depth_gt = depth_gt / torch.max(depth_gt)

        depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
        depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())

        return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
221
+
222
+
223
class BNHead(DepthBaseDecodeHead):
    """Linear depth head: a 1x1 conv applied to (optionally concatenated) features."""

    def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
        super().__init__(**kwargs)
        self.input_transform = input_transform
        self.in_index = in_index
        self.upsample = upsample
        # self.bn = nn.SyncBatchNorm(self.in_channels)
        # Replace the base 3x3 prediction conv with a 1x1 one.
        out_dim = self.n_bins if self.classify else 1
        self.conv_depth = nn.Conv2d(self.channels, out_dim, kernel_size=1, padding=0, stride=1)

    def _transform_inputs(self, inputs):
        """Select (and optionally resize/concatenate) backbone features.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
        Returns:
            Tensor | list[Tensor]: The transformed inputs.
        """
        mode = self.input_transform

        if "concat" not in mode:
            if mode == "multiple_select":
                return [inputs[idx] for idx in self.in_index]
            return inputs[self.in_index]

        selected = [inputs[idx] for idx in self.in_index]
        if "resize" in mode:
            # Every level is brought to the (upsampled) size of the first one.
            target_size = [dim * self.upsample for dim in selected[0].shape[2:]]
            selected = [
                resize(
                    input=level,
                    size=target_size,
                    mode="bilinear",
                    align_corners=self.align_corners,
                )
                for level in selected
            ]
        return torch.cat(selected, dim=1)

    def _forward_feature(self, inputs, img_metas=None, **kwargs):
        """Prepare the feature map that feeds the depth prediction conv.

        Args:
            inputs (list): Per-level entries; each is indexed as
                ``(features, cls_token)`` when it has length 2, otherwise only
                its first element is used.
        Returns:
            Tensor | list[Tensor]: Output of ``_transform_inputs``.
        """
        # accept lists (for cls token)
        prepared = []
        for entry in inputs:
            with_cls_token = len(entry) == 2
            feature = entry[0]
            if feature.dim() == 2:
                feature = feature[:, :, None, None]
            if with_cls_token:
                # Broadcast the class token over the spatial grid and stack it
                # on the channel axis.
                token_map = entry[1][:, :, None, None].expand_as(feature)
                feature = torch.cat((feature, token_map), 1)
            prepared.append(feature)
        # feats = self.bn(x)
        return self._transform_inputs(prepared)

    def forward(self, inputs, img_metas=None, **kwargs):
        """Forward function."""
        features = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
        return self.depth_pred(features)
297
+
298
+
299
class ConvModule(nn.Module):
    """A conv block that bundles conv/norm/activation layers.

    This block simplifies the usage of convolution layers, which are commonly
    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).

    Besides, we add some additional features in this module.
    1. Automatically set `bias` of the conv layer.
    2. Spectral norm is supported.
    3. Padding modes not handled by the conv layer itself ("reflect",
       "replicate") are applied through an explicit padding layer.

    Args:
        in_channels (int): Number of channels in the input feature map.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
        stride (int | tuple[int]): Stride of the convolution.
        padding (int | tuple[int]): Padding added to both sides of the input.
        dilation (int | tuple[int]): Spacing between kernel elements.
        groups (int): Number of blocked connections from input channels to
            output channels.
        bias (bool | str): If specified as `auto`, it will be decided by the
            norm_layer. Bias will be set as True if `norm_layer` is None,
            otherwise False. Default: "auto".
        conv_layer (nn.Module): Convolution layer class. Default: nn.Conv2d;
            None also means nn.Conv2d.
        norm_layer (nn.Module): Normalization layer class, called with the
            number of channels. Default: None.
        act_layer (nn.Module): Activation layer class. Default: nn.ReLU.
        inplace (bool): Whether to use inplace mode for activation.
            Default: True.
        with_spectral_norm (bool): Whether use spectral norm in conv module.
            Default: False.
        padding_mode (str): 'zeros' and 'circular' are passed through to the
            conv arguments as before; 'reflect' and 'replicate' use an
            explicit padding layer. Default: 'zeros'.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
            Default: ('conv', 'norm', 'act').

    Raises:
        AssertionError: If ``order`` is malformed or ``padding_mode`` is not
            supported.
    """

    _abbr_ = "conv_block"

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias="auto",
        conv_layer=nn.Conv2d,
        norm_layer=None,
        act_layer=nn.ReLU,
        inplace=True,
        with_spectral_norm=False,
        padding_mode="zeros",
        order=("conv", "norm", "act"),
    ):
        super().__init__()
        official_padding_mode = ["zeros", "circular"]
        # The documented default "None" means a plain 2D convolution.
        if conv_layer is None:
            conv_layer = nn.Conv2d
        self.conv_layer = conv_layer
        self.norm_layer = norm_layer
        self.act_layer = act_layer
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == {"conv", "norm", "act"}

        self.with_norm = norm_layer is not None
        self.with_activation = act_layer is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == "auto":
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_explicit_padding:
            explicit_padding_layers = {
                "reflect": nn.ReflectionPad2d,
                "replicate": nn.ReplicationPad2d,
            }
            if padding_mode not in explicit_padding_layers:
                raise AssertionError(f"Unsupported padding mode: {padding_mode}")
            self.pad = explicit_padding_layers[padding_mode](padding)

        # reset padding to 0 for conv module
        conv_padding = 0 if self.with_explicit_padding else padding
        # build convolution layer
        self.conv = self.conv_layer(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer
            if order.index("norm") > order.index("conv"):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            # Register an actual module instance under a name that does not
            # clash with the `norm` property, which resolves it via norm_name.
            norm = norm_layer(norm_channels)
            self.norm_name = "norm_module"
            self.add_module(self.norm_name, norm)
            if self.with_bias:
                from torch.nn.modules.batchnorm import _BatchNorm
                from torch.nn.modules.instancenorm import _InstanceNorm

                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
                    warnings.warn("Unnecessary conv bias before batch/instance norm")
        else:
            self.norm_name = None

        # build activation layer
        if self.with_activation:
            # These activations have no 'inplace' argument. act_layer is a
            # class (or factory), so it must be compared with issubclass.
            no_inplace_layers = (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)
            lacks_inplace = isinstance(act_layer, type) and issubclass(act_layer, no_inplace_layers)
            self.activate = act_layer() if lacks_inplace else act_layer(inplace=inplace)

        # Use msra init by default
        self.init_weights()

    @property
    def norm(self):
        """The registered normalization module, or None when there is none."""
        if self.norm_name:
            return getattr(self, self.norm_name)
        else:
            return None

    def init_weights(self):
        """Kaiming-initialize the conv and reset the norm to identity.

        Conv layers that define their own ``init_weights()`` are left
        untouched so their custom initialization is not overridden.
        """
        if not hasattr(self.conv, "init_weights"):
            if self.with_activation and isinstance(self.activate, nn.LeakyReLU):
                nonlinearity = "leaky_relu"
                a = self.activate.negative_slope
            else:
                nonlinearity = "relu"
                a = 0
            if hasattr(self.conv, "weight") and self.conv.weight is not None:
                nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
            if hasattr(self.conv, "bias") and self.conv.bias is not None:
                nn.init.constant_(self.conv.bias, 0)
        if self.with_norm:
            if hasattr(self.norm, "weight") and self.norm.weight is not None:
                nn.init.constant_(self.norm.weight, 1)
            if hasattr(self.norm, "bias") and self.norm.bias is not None:
                nn.init.constant_(self.norm.bias, 0)

    def forward(self, x, activate=True, norm=True):
        """Apply the layers in ``self.order``.

        Args:
            x (Tensor): Input feature map.
            activate (bool): Set False to skip the activation.
            norm (bool): Set False to skip the normalization.
        """
        for layer in self.order:
            if layer == "conv":
                if self.with_explicit_padding:
                    x = self.pad(x)
                x = self.conv(x)
            elif layer == "norm" and norm and self.with_norm:
                x = self.norm(x)
            elif layer == "act" and activate and self.with_activation:
                x = self.activate(x)
        return x
497
+
498
+
499
class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` with fixed settings."""

    def __init__(self, scale_factor, mode, align_corners=False):
        super().__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resample ``x`` using the stored scale factor, mode and align_corners."""
        return self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
510
+
511
+
512
class HeadDepth(nn.Module):
    """Depth head: conv, 2x bilinear upsampling, conv + ReLU, then a 1x1 conv to one channel."""

    def __init__(self, features):
        super().__init__()
        half_features = features // 2
        stages = [
            nn.Conv2d(features, half_features, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(half_features, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        ]
        self.head = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the sequential head."""
        return self.head(x)
526
+
527
+
528
class ReassembleBlocks(nn.Module):
    """ViTPostProcessBlock: merge the cls_token into ViT features and project
    and resize each stage.

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List): output channels of each stage.
            Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
    """

    def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16):
        super().__init__()

        assert readout_type in ["ignore", "add", "project"]
        self.readout_type = readout_type
        self.patch_size = patch_size

        # One 1x1 projection per stage, without activation.
        stage_projections = [
            ConvModule(
                in_channels=in_channels,
                out_channels=stage_channels,
                kernel_size=1,
                act_layer=None,
            )
            for stage_channels in out_channels
        ]
        self.projects = nn.ModuleList(stage_projections)

        # Stage 0: 4x upsample, stage 1: 2x upsample, stage 2: unchanged,
        # stage 3: stride-2 downsample.
        c0, c1, c3 = out_channels[0], out_channels[1], out_channels[3]
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(in_channels=c0, out_channels=c0, kernel_size=4, stride=4, padding=0),
                nn.ConvTranspose2d(in_channels=c1, out_channels=c1, kernel_size=2, stride=2, padding=0),
                nn.Identity(),
                nn.Conv2d(in_channels=c3, out_channels=c3, kernel_size=3, stride=2, padding=1),
            ]
        )
        if self.readout_type == "project":
            self.readout_projects = nn.ModuleList()
            for _ in self.projects:
                fuse = nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU())
                self.readout_projects.append(fuse)

    def forward(self, inputs):
        """Process a list of ``(feature_map, cls_token)`` pairs, one per stage.

        Returns:
            list[Tensor]: Projected and resized feature map for each stage.
        """
        assert isinstance(inputs, list)
        outputs = []
        for stage, pair in enumerate(inputs):
            assert len(pair) == 2
            feat, cls_token = pair[0], pair[1]
            original_shape = feat.shape
            if self.readout_type == "project":
                # Concatenate the token to every position, then fuse.
                tokens = feat.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(tokens)
                tokens = self.readout_projects[stage](torch.cat((tokens, readout), -1))
                feat = tokens.permute(0, 2, 1).reshape(original_shape)
            elif self.readout_type == "add":
                feat = (feat.flatten(2) + cls_token.unsqueeze(-1)).reshape(original_shape)
            # "ignore": the cls_token is dropped.
            feat = self.resize_layers[stage](self.projects[stage](feat))
            outputs.append(feat)
        return outputs
598
+
599
+
600
class PreActResidualConvUnit(nn.Module):
    """Pre-activation residual unit: two act -> conv -> norm blocks plus a skip.

    Args:
        in_channels (int): number of channels in the input feature map.
        act_layer (nn.Module): activation layer.
        norm_layer (nn.Module): norm layer.
        stride (int): stride of the first block. Default: 1
        dilation (int): dilation rate for convs layers. Default: 1.
    """

    def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
        super().__init__()

        # Settings shared by both 3x3 conv blocks.
        common = dict(
            norm_layer=norm_layer,
            act_layer=act_layer,
            bias=False,
            order=("act", "conv", "norm"),
        )

        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            **common,
        )

        self.conv2 = ConvModule(in_channels, in_channels, 3, padding=1, **common)

    def forward(self, inputs):
        """Return ``conv2(conv1(inputs)) + inputs``."""
        # Copy the input before the conv blocks run; the copy is the skip branch.
        shortcut = inputs.clone()
        residual = self.conv2(self.conv1(inputs))
        return residual + shortcut
642
+
643
+
644
class FeatureFusionBlock(nn.Module):
    """Fuse feature maps from different stages and upsample the result by 2.

    Args:
        in_channels (int): Input channels.
        act_layer (nn.Module): activation layer for ResidualConvUnit.
        norm_layer (nn.Module): normalization layer.
        expand (bool): Whether expand the channels in post process block.
            Default: False.
        align_corners (bool): align_corner setting for bilinear upsample.
            Default: True.
    """

    def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
        super().__init__()

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners

        # With `expand` the output projection halves the channel count.
        self.out_channels = in_channels // 2 if self.expand else in_channels

        self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)

        unit_kwargs = dict(in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer)
        self.res_conv_unit1 = PreActResidualConvUnit(**unit_kwargs)
        self.res_conv_unit2 = PreActResidualConvUnit(**unit_kwargs)

    def forward(self, *inputs):
        """Fuse one or two feature maps.

        ``inputs[0]`` is the running feature map. When a second map is given
        it is resized to match if needed, refined by ``res_conv_unit1`` and
        added before the shared refinement, upsampling and projection.
        """
        fused = inputs[0]
        if len(inputs) == 2:
            skip = inputs[1]
            if fused.shape != skip.shape:
                skip = resize(skip, size=(fused.shape[2], fused.shape[3]), mode="bilinear", align_corners=False)
            fused = fused + self.res_conv_unit1(skip)
        fused = self.res_conv_unit2(fused)
        fused = resize(fused, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
        return self.project(fused)
688
+
689
+
690
class DPTHead(DepthBaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head is an implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (List): Out channels of post process conv
            layers. Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether expand the channels in post process
            block. Default: False.
    """

    def __init__(
        self,
        embed_dims=768,
        post_process_channels=[96, 192, 384, 768],
        readout_type="ignore",
        patch_size=16,
        expand_channels=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)

        # NOTE(review): with expand_channels=True the entries become floats
        # (math.pow) and no longer match the reassemble output widths;
        # confirm whether this option is actually usable.
        if expand_channels:
            stage_channels = [channel * math.pow(2, i) for i, channel in enumerate(post_process_channels)]
        else:
            stage_channels = [channel for channel in post_process_channels]
        self.post_process_channels = stage_channels

        # Per-stage 3x3 convs that bring every stage to `self.channels`.
        self.convs = nn.ModuleList()
        for stage_width in self.post_process_channels:
            stage_conv = ConvModule(stage_width, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False)
            self.convs.append(stage_conv)

        self.fusion_blocks = nn.ModuleList(
            FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer) for _ in self.convs
        )
        # The first fusion block only ever receives a single input.
        self.fusion_blocks[0].res_conv_unit1 = None

        self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)

        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels

        self.conv_depth = HeadDepth(self.channels)

    def forward(self, inputs, img_metas):
        """Reassemble the backbone stages, fuse them from last to first and predict depth."""
        assert len(inputs) == self.num_reassemble_blocks
        stages = self.reassemble_blocks(list(inputs))
        stages = [self.convs[idx](stage) for idx, stage in enumerate(stages)]

        # Start from the last stage and merge in earlier ones one by one.
        fused = self.fusion_blocks[0](stages[-1])
        for idx in range(1, len(self.fusion_blocks)):
            fused = self.fusion_blocks[idx](fused, stages[-(idx + 1)])

        fused = self.project(fused)
        return self.depth_pred(fused)
torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/encoder_decoder.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from .ops import resize
13
+
14
+
15
def add_prefix(inputs, prefix):
    """Return a copy of ``inputs`` whose keys are prefixed with ``prefix``.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: New dict mapping ``"<prefix>.<key>"`` to the original values.
    """
    return {f"{prefix}.{name}": value for name, value in inputs.items()}
32
+
33
+
34
+ class DepthEncoderDecoder(nn.Module):
35
+ """Encoder Decoder depther.
36
+
37
+ EncoderDecoder typically consists of backbone and decode_head.
38
+ """
39
+
40
+ def __init__(self, backbone, decode_head):
41
+ super(DepthEncoderDecoder, self).__init__()
42
+
43
+ self.backbone = backbone
44
+ self.decode_head = decode_head
45
+ self.align_corners = self.decode_head.align_corners
46
+
47
+ def extract_feat(self, img):
48
+ """Extract features from images."""
49
+ return self.backbone(img)
50
+
51
+ def encode_decode(self, img, img_metas, rescale=True, size=None):
52
+ """Encode images with backbone and decode into a depth estimation
53
+ map of the same size as input."""
54
+ x = self.extract_feat(img)
55
+ out = self._decode_head_forward_test(x, img_metas)
56
+ # crop the pred depth to the certain range.
57
+ out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
58
+ if rescale:
59
+ if size is None:
60
+ if img_metas is not None:
61
+ size = img_metas[0]["ori_shape"][:2]
62
+ else:
63
+ size = img.shape[2:]
64
+ out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
65
+ return out
66
+
67
+ def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
68
+ """Run forward function and calculate loss for decode head in
69
+ training."""
70
+ losses = dict()
71
+ loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs)
72
+ losses.update(add_prefix(loss_decode, "decode"))
73
+ return losses
74
+
75
+ def _decode_head_forward_test(self, x, img_metas):
76
+ """Run forward function and calculate loss for decode head in
77
+ inference."""
78
+ depth_pred = self.decode_head.forward_test(x, img_metas)
79
+ return depth_pred
80
+
81
+ def forward_dummy(self, img):
82
+ """Dummy forward function."""
83
+ depth = self.encode_decode(img, None)
84
+
85
+ return depth
86
+
87
+ def forward_train(self, img, img_metas, depth_gt, **kwargs):
88
+ """Forward function for training.
89
+
90
+ Args:
91
+ img (Tensor): Input images.
92
+ img_metas (list[dict]): List of image info dict where each dict
93
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
94
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
95
+ For details on the values of these keys see
96
+ `depth/datasets/pipelines/formatting.py:Collect`.
97
+ depth_gt (Tensor): Depth gt
98
+ used if the architecture supports depth estimation task.
99
+
100
+ Returns:
101
+ dict[str, Tensor]: a dictionary of loss components
102
+ """
103
+
104
+ x = self.extract_feat(img)
105
+
106
+ losses = dict()
107
+
108
+ # the last of x saves the info from neck
109
+ loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
110
+
111
+ losses.update(loss_decode)
112
+
113
+ return losses
114
+
115
+ def whole_inference(self, img, img_meta, rescale, size=None):
116
+ """Inference with full image."""
117
+ return self.encode_decode(img, img_meta, rescale, size=size)
118
+
119
def slide_inference(self, img, img_meta, rescale, stride, crop_size):
    """Predict with overlapping sliding windows and average the overlaps.

    When the crop is larger than the image along an axis, the (smaller)
    image extent is decoded as-is, without padding.

    Args:
        img (Tensor): 4-D input; its last two dimensions are tiled.
        img_meta: Forwarded unchanged to ``self.encode_decode``.
        rescale: Forwarded unchanged to ``self.encode_decode``.
        stride (tuple): ``(h_stride, w_stride)`` step between windows.
        crop_size (tuple): ``(h_crop, w_crop)`` window size.

    Returns:
        Tensor: Per-pixel mean of the window predictions, with one channel
        and the same batch and spatial size as ``img``.
    """
    stride_h, stride_w = stride
    crop_h, crop_w = crop_size
    n_imgs, _, img_h, img_w = img.size()
    # Number of window positions needed along each axis to cover the image.
    n_rows = max(img_h - crop_h + stride_h - 1, 0) // stride_h + 1
    n_cols = max(img_w - crop_w + stride_w - 1, 0) // stride_w + 1
    canvas_shape = (n_imgs, 1, img_h, img_w)
    preds = img.new_zeros(canvas_shape)
    count_mat = img.new_zeros(canvas_shape)
    for row in range(n_rows):
        # Clamp the window to the image, then pull its start back so the
        # window keeps the full crop size whenever the image allows it.
        bottom = min(row * stride_h + crop_h, img_h)
        top = max(bottom - crop_h, 0)
        for col in range(n_cols):
            right = min(col * stride_w + crop_w, img_w)
            left = max(right - crop_w, 0)
            window = img[:, :, top:bottom, left:right]
            window_pred = self.encode_decode(window, img_meta, rescale)
            # Zero-pad the window prediction back to the full canvas size.
            padding = (int(left), int(preds.shape[3] - right), int(top), int(preds.shape[2] - bottom))
            preds += F.pad(window_pred, padding)

            count_mat[:, :, top:bottom, left:right] += 1
    # Every pixel must have been covered by at least one window.
    assert (count_mat == 0).sum() == 0
    if torch.onnx.is_in_onnx_export():
        # cast count_mat to constant while exporting to ONNX
        count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
    preds = preds / count_mat
    return preds
152
+
153
def inference(self, img, img_meta, rescale, size=None, mode="whole", stride=None, crop_size=None):
    """Inference with slide/whole style.

    Args:
        img (Tensor): The input image of shape (N, 3, H, W).
        img_meta (list[dict]): Image info dicts. Every entry must carry the
            same 'ori_shape'; the first entry's 'flip' (and, when set,
            'flip_direction') decides whether the output is flipped back.
        rescale (bool): Whether rescale back to original shape.
        size: Forwarded to ``whole_inference`` (unused in slide mode).
        mode (str): Either "whole" or "slide".
        stride (tuple | None): ``(h_stride, w_stride)`` for slide mode.
        crop_size (tuple | None): ``(h_crop, w_crop)`` for slide mode.

    Returns:
        Tensor: The output depth map.

    Raises:
        ValueError: If ``mode == "slide"`` and ``stride`` or ``crop_size``
            is not given.
    """

    assert mode in ["slide", "whole"]
    ori_shape = img_meta[0]["ori_shape"]
    assert all(_["ori_shape"] == ori_shape for _ in img_meta)
    if mode == "slide":
        # slide_inference requires the window geometry; the previous call
        # omitted both arguments and therefore always raised TypeError.
        if stride is None or crop_size is None:
            raise ValueError("mode='slide' requires both `stride` and `crop_size`")
        depth_pred = self.slide_inference(img, img_meta, rescale, stride, crop_size)
    else:
        depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
    output = depth_pred
    flip = img_meta[0]["flip"]
    if flip:
        # Undo the test-time flip so the prediction lines up with the image.
        flip_direction = img_meta[0]["flip_direction"]
        assert flip_direction in ["horizontal", "vertical"]
        if flip_direction == "horizontal":
            output = output.flip(dims=(3,))
        elif flip_direction == "vertical":
            output = output.flip(dims=(2,))

    return output
187
+
188
def simple_test(self, img, img_meta, rescale=True):
    """Run inference for a single (non-augmented) input.

    Returns:
        During ONNX export, the prediction tensor with an extra leading
        dimension; otherwise a list of NumPy arrays, one per batch element.
    """
    depth_pred = self.inference(img, img_meta, rescale)
    if torch.onnx.is_in_onnx_export():
        # our inference backend only support 4D output
        return depth_pred.unsqueeze(0)
    # Move to host memory and split along the batch dimension.
    return list(depth_pred.cpu().numpy())
199
+
200
def aug_test(self, imgs, img_metas, rescale=True):
    """Average the predictions over several test-time augmentations.

    Only ``rescale=True`` is supported.

    Returns:
        list: NumPy arrays, one per batch element, holding the mean
        prediction over all augmentations.
    """
    # aug_test rescale all imgs back to ori_shape for now
    assert rescale
    num_augs = len(imgs)
    # The first prediction doubles as the accumulator to save memory.
    depth_pred = self.inference(imgs[0], img_metas[0], rescale)
    for idx in range(1, num_augs):
        # Later predictions are requested at the accumulator's spatial size.
        depth_pred += self.inference(imgs[idx], img_metas[idx], rescale, size=depth_pred.shape[-2:])
    depth_pred /= num_augs
    # Move to host memory and split along the batch dimension.
    return list(depth_pred.cpu().numpy())
217
+
218
def forward_test(self, imgs, img_metas, **kwargs):
    """Validate test inputs and dispatch to ``simple_test`` or ``aug_test``.

    Args:
        imgs (List[Tensor]): the outer list indicates test-time
            augmentations and inner Tensor should have a shape NxCxHxW,
            which contains all images in the batch.
        img_metas (List[List[dict]]): the outer list indicates test-time
            augs (multiscale, flip, etc.) and the inner list indicates
            images in a batch.

    Raises:
        TypeError: If ``imgs`` or ``img_metas`` is not a list.
        ValueError: If the two lists differ in length.
    """
    if not isinstance(imgs, list):
        raise TypeError(f"imgs must be a list, but got {type(imgs)}")
    if not isinstance(img_metas, list):
        raise TypeError(f"img_metas must be a list, but got {type(img_metas)}")
    num_augs = len(imgs)
    if num_augs != len(img_metas):
        raise ValueError(f"num of augmentations ({len(imgs)}) != num of image meta ({len(img_metas)})")
    # Within one augmentation, all images must share ori/img/pad shape.
    for batch_metas in img_metas:
        for key in ("ori_shape", "img_shape", "pad_shape"):
            shapes = [meta[key] for meta in batch_metas]
            assert all(shape == shapes[0] for shape in shapes)

    if num_augs != 1:
        return self.aug_test(imgs, img_metas, **kwargs)
    return self.simple_test(imgs[0], img_metas[0], **kwargs)
248
+
249
def forward(self, img, img_metas, return_loss=True, **kwargs):
    """Dispatch to :func:`forward_train` or :func:`forward_test`.

    The expected input nesting depends on ``return_loss``. When
    ``return_loss=True``, ``img`` and ``img_metas`` are single-nested
    (Tensor and List[dict]). When ``return_loss=False`` they are
    double-nested (List[Tensor] and List[List[dict]]), the outer list
    indexing test-time augmentations.
    """
    handler = self.forward_train if return_loss else self.forward_test
    return handler(img, img_metas, **kwargs)
263
+
264
def train_step(self, data_batch, optimizer, **kwargs):
    """Run one training iteration (forward pass and loss parsing only).

    Back-propagation and the optimizer update are not performed here.

    Args:
        data_batch (dict): Keyword arguments for calling the model; must
            contain the key ``"img_metas"``.
        optimizer: Unused by this method.
        **kwargs: Unused by this method.

    Returns:
        dict: With keys
        ``loss`` (first value returned by ``self._parse_losses``),
        ``log_vars`` (second value returned by ``self._parse_losses``),
        ``num_samples`` (``len(data_batch["img_metas"])``) and
        ``log_imgs`` (model outputs whose key contains ``"img"``).
    """
    outputs = self(**data_batch)

    # Entries whose key mentions "img" are images to log, not losses.
    log_imgs = {key: val for key, val in outputs.items() if "img" in key}
    real_losses = {key: val for key, val in outputs.items() if "img" not in key}

    loss, log_vars = self._parse_losses(real_losses)

    return dict(
        loss=loss,
        log_vars=log_vars,
        num_samples=len(data_batch["img_metas"]),
        log_imgs=log_imgs,
    )
306
+
307
def val_step(self, data_batch, **kwargs):
    """Run one validation iteration.

    Calls the model with the entries of ``data_batch`` and ``kwargs`` as
    keyword arguments and returns the result unchanged.
    """
    return self(**data_batch, **kwargs)
316
+
317
@staticmethod
def _parse_losses(losses):
    """Parse the raw outputs (losses) of the network.

    Args:
        losses (dict): Raw output of the network. Each value must be a
            tensor or a list of tensors.

    Returns:
        tuple[Tensor, dict]: ``(loss, log_vars)``. ``loss`` is the sum of the
        reduced entries whose key contains ``"loss"``; ``log_vars`` maps
        every key (plus ``"loss"``) to a Python scalar for logging.

    Raises:
        TypeError: If a value is neither a tensor nor a list of tensors.
    """
    # Imported locally, after the docstring so that it stays a real docstring.
    import torch.distributed as dist

    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(f"{loss_name} is not a tensor or list of tensors")

    loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)

    log_vars["loss"] = loss
    # The distributed state does not change inside the loop; query it once.
    is_distributed = dist.is_available() and dist.is_initialized()
    for loss_name, loss_value in log_vars.items():
        # reduce loss when distributed training
        if is_distributed:
            loss_value = loss_value.data.clone()
            dist.all_reduce(loss_value.div_(dist.get_world_size()))
        # Only existing keys are overwritten, so iterating stays safe.
        log_vars[loss_name] = loss_value.item()

    return loss, log_vars
torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/ops.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import warnings
7
+
8
+ import torch.nn.functional as F
9
+
10
+
11
def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
    """Resize ``input`` with ``F.interpolate``, optionally warning about alignment.

    Args:
        input (Tensor): Tensor to resize; its dimensions from index 2 on are
            read as ``(height, width)`` when the warning check runs.
        size: Target ``(height, width)``, forwarded to ``F.interpolate``.
        scale_factor: Forwarded to ``F.interpolate``.
        mode (str): Interpolation mode, forwarded to ``F.interpolate``.
        align_corners: Forwarded to ``F.interpolate``.
        warning (bool): When true, and ``size`` is given with a truthy
            ``align_corners``, warn if the input is upsampled along either
            axis to a size that is not of the form ``n * (input - 1) + 1``.

    Returns:
        Tensor: The result of ``F.interpolate``.
    """
    if warning and size is not None and align_corners:
        input_h, input_w = tuple(int(x) for x in input.shape[2:])
        output_h, output_w = tuple(int(x) for x in size)
        # Upsampling along either axis. The width test must compare against
        # the input width (it used to compare against the output height).
        if output_h > input_h or output_w > input_w:
            if (
                (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
                and (output_h - 1) % (input_h - 1)
                and (output_w - 1) % (input_w - 1)
            ):
                warnings.warn(
                    f"When align_corners={align_corners}, "
                    "the output would more aligned if "
                    f"input size {(input_h, input_w)} is `x+1` and "
                    f"out size {(output_h, output_w)} is `nx+1`"
                )
    return F.interpolate(input, size, scale_factor, mode, align_corners)
torch_hub/facebookresearch_dinov2_main/dinov2/hub/utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
15
+
16
+
17
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
    """Build a model name such as ``dinov2_vitl14`` or ``dinov2_vitl14_reg4``.

    The architecture name is stripped of underscores and cut to its first
    four characters; a ``_reg<N>`` suffix is added only when
    ``num_register_tokens`` is non-zero.
    """
    short_arch = "".join(arch_name.split("_"))[:4]
    model_name = f"dinov2_{short_arch}{patch_size}"
    if num_register_tokens:
        model_name += f"_reg{num_register_tokens}"
    return model_name
21
+
22
+
23
class CenterPadding(nn.Module):
    """Zero-pad every dimension after the first two up to a multiple of ``multiple``.

    The padding is split between both sides; when it is odd, the extra
    element goes to the trailing side.
    """

    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        """Return ``(pad_before, pad_after)`` bringing ``size`` to the next multiple."""
        target = math.ceil(size / self.multiple) * self.multiple
        before, odd = divmod(target - size, 2)
        return before, before + odd

    @torch.inference_mode()
    def forward(self, x):
        # F.pad expects the pads for the last dimension first, so walk the
        # padded dimensions from last to first.
        pads = []
        for dim_size in reversed(x.shape[2:]):
            pads.extend(self._get_pad(dim_size))
        return F.pad(x, pads)
torch_hub/facebookresearch_dinov2_main/dinov2/hub/xray_dino/__pycache__/backbones.cpython-310.pyc ADDED
Binary file (847 Bytes). View file
 
torch_hub/facebookresearch_dinov2_main/dinov2/hub/xray_dino/backbones.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the licence
4
+ # found in the LICENSE_XRAY_DINO_MODEL file in the root directory of this source tree.
5
+
6
+ from typing import Union
7
+
8
+ from ..backbones import Weights, _make_dinov2_model
9
+
10
+
11
def xray_dino_vitl16(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.XRAY_DINO, **kwargs):
    """
    XRay-DINO ViT-L/16 model (optionally) pretrained on the XRay-DINO dataset.

    Fixed architecture settings are passed to ``_make_dinov2_model`` together
    with ``pretrained``, ``weights`` and any extra keyword arguments.
    """
    fixed_config = dict(
        arch_name="vit_large",
        patch_size=16,
        img_size=512,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        block_chunks=4,
        pretrained=pretrained,
        weights=weights,
        hash="ad31c2b0",
        check_hash=True,
    )
    # Unpacking both mappings in the call keeps the TypeError raised when a
    # caller passes a keyword that collides with a fixed setting.
    return _make_dinov2_model(**fixed_config, **kwargs)
torch_hub/facebookresearch_dinov2_main/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .layer_scale import LayerScale
8
+ from .mlp import Mlp
9
+ from .patch_embed import PatchEmbed
10
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused, SwiGLUFFNAligned
11
+ from .block import NestedTensorBlock, CausalAttentionBlock
12
+ from .attention import Attention, MemEffAttention
torch_hub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/dino_head.cpython-310.pyc ADDED
Binary file (1.99 kB). View file
 
torch_hub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc ADDED
Binary file (1.35 kB). View file
 
torch_hub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc ADDED
Binary file (3.25 kB). View file
 
torch_hub/facebookresearch_dinov2_main/dinov2/layers/attention.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ import torch
15
+ from torch import nn, Tensor
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
# xFormers is used unless the XFORMERS_DISABLED environment variable is set
# (to any value) or the package cannot be imported.
XFORMERS_ENABLED = "XFORMERS_DISABLED" not in os.environ
XFORMERS_AVAILABLE = False
if XFORMERS_ENABLED:
    try:
        from xformers.ops import memory_efficient_attention, unbind
    except ImportError:
        warnings.warn("xFormers is not available (Attention)")
    else:
        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Attention)")
else:
    # The disabled path reports both messages, in this order.
    warnings.warn("xFormers is disabled (Attention)")
    warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
+ class Attention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int = 8,
41
+ qkv_bias: bool = False,
42
+ proj_bias: bool = True,
43
+ attn_drop: float = 0.0,
44
+ proj_drop: float = 0.0,
45
+ ) -> None:
46
+ super().__init__()
47
+ self.dim = dim
48
+ self.num_heads = num_heads
49
+ head_dim = dim // num_heads
50
+ self.scale = head_dim**-0.5
51
+
52
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
53
+ self.attn_drop = attn_drop
54
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
55
+ self.proj_drop = nn.Dropout(proj_drop)
56
+
57
+ def init_weights(
58
+ self, init_attn_std: float | None = None, init_proj_std: float | None = None, factor: float = 1.0
59
+ ) -> None:
60
+ init_attn_std = init_attn_std or (self.dim**-0.5)
61
+ init_proj_std = init_proj_std or init_attn_std * factor
62
+ nn.init.normal_(self.qkv.weight, std=init_attn_std)
63
+ nn.init.normal_(self.proj.weight, std=init_proj_std)
64
+ if self.qkv.bias is not None:
65
+ nn.init.zeros_(self.qkv.bias)
66
+ if self.proj.bias is not None:
67
+ nn.init.zeros_(self.proj.bias)
68
+
69
+ def forward(self, x: Tensor, is_causal: bool = False) -> Tensor:
70
+ B, N, C = x.shape
71
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
72
+ q, k, v = torch.unbind(qkv, 2)
73
+ q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
74
+ x = nn.functional.scaled_dot_product_attention(
75
+ q, k, v, attn_mask=None, dropout_p=self.attn_drop if self.training else 0, is_causal=is_causal
76
+ )
77
+ x = x.transpose(1, 2).contiguous().view(B, N, C)
78
+ x = self.proj_drop(self.proj(x))
79
+ return x
80
+
81
+
82
class MemEffAttention(Attention):
    """Attention that uses xFormers' ``memory_efficient_attention`` when available."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if XFORMERS_AVAILABLE:
            B, N, C = x.shape
            fused = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

            q, k, v = unbind(fused, 2)

            attended = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
            attended = attended.reshape([B, N, C])

            return self.proj_drop(self.proj(attended))

        # Without xFormers only the plain (bias-free) path of the parent exists.
        if attn_bias is not None:
            raise AssertionError("xFormers is required for using nested tensors")
        return super().forward(x)
torch_hub/facebookresearch_dinov2_main/dinov2/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
def make_2tuple(x):
    """Return ``x`` if it is a 2-tuple, or ``(x, x)`` if it is an int.

    Any other input fails an assertion.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)

    assert len(x) == 2
    return x
23
+
24
+
25
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer factory, called with ``embed_dim``;
            when omitted, an identity is used.
        flatten_embedding: When false, ``forward`` returns (B,H',W',D)
            instead of (B,N,D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping convolution: one output position per patch.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x``; its height and width must be multiples of the patch size."""
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Return the FLOP estimate for the configured image size."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # `self.norm` is never None (it defaults to nn.Identity), so test for
        # a real normalization layer instead of counting it unconditionally.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
torch_hub/facebookresearch_dinov2_main/dinov2/logging/__init__.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import functools
7
+ import logging
8
+ import os
9
+ import sys
10
+ from typing import Optional
11
+
12
+ import dinov2.distributed as distributed
13
+ from .helpers import MetricLogger, SmoothedValue
14
+
15
+
16
+ # So that calling _configure_logger multiple times won't add many handlers
17
@functools.lru_cache()
def _configure_logger(
    name: Optional[str] = None,
    *,
    level: int = logging.DEBUG,
    output: Optional[str] = None,
):
    """
    Configure a logger.

    Adapted from Detectron2.

    Args:
        name: The name of the logger to configure.
        level: The logging level to use.
        output: A file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/logs/log.txt`.
            Non-main processes append ".rank<N>" to the file name.

    Returns:
        The configured logger.
    """

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False

    # Loosely match Google glog format:
    #   [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg
    # but use a shorter timestamp and include the process id and logger name:
    #   [IWEF]yyyymmdd hh:mm:ss process logger file:line] msg
    fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] "
    fmt_message = "%(message)s"
    fmt = fmt_prefix + fmt_message
    datefmt = "%Y%m%d %H:%M:%S"
    formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)

    # stdout logging for main worker only
    if distributed.is_main_process():
        handler = logging.StreamHandler(stream=sys.stdout)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    # file logging for all workers
    if output:
        if os.path.splitext(output)[-1] in (".txt", ".log"):
            filename = output
        else:
            filename = os.path.join(output, "logs", "log.txt")

        if not distributed.is_main_process():
            global_rank = distributed.get_global_rank()
            filename = filename + ".rank{}".format(global_rank)

        # A bare file name has no directory part; os.makedirs("") would raise.
        log_dir = os.path.dirname(filename)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        # FileHandler owns the file object, so closing the handler (or
        # logging.shutdown) also closes the file.
        handler = logging.FileHandler(filename, mode="a")
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
80
+
81
+
82
def setup_logging(
    output: Optional[str] = None,
    *,
    name: Optional[str] = None,
    level: int = logging.DEBUG,
    capture_warnings: bool = True,
) -> None:
    """
    Set up logging for the process.

    Args:
        output: A file name or a directory for the log file. If None, no log
            file is written. A value ending in ".txt" or ".log" is treated as
            a file name; anything else is treated as a directory and the log
            goes to `output/logs/log.txt`.
        name: The name of the logger to configure, by default the root logger.
        level: The logging level to use.
        capture_warnings: Passed to ``logging.captureWarnings``.
    """
    logging.captureWarnings(capture_warnings)
    _configure_logger(name, level=level, output=output)
torch_hub/facebookresearch_dinov2_main/dinov2/logging/helpers.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import defaultdict, deque
7
+ import datetime
8
+ import json
9
+ import logging
10
+ import time
11
+
12
+ import torch
13
+
14
+ import dinov2.distributed as distributed
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
class MetricLogger(object):
    """Collect named meters and periodically log them while iterating.

    Meters are created on first use (``SmoothedValue`` by default) and can be
    read as attributes, e.g. ``metric_logger.loss``.
    """

    def __init__(self, delimiter="\t", output_file=None):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        self.output_file = output_file

    def update(self, **kwargs):
        """Feed one value per keyword into the meter of the same name."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. Read `meters`
        # through __dict__: on an instance without `meters` (created via
        # __new__, copy or unpickling) `self.meters` would re-enter this
        # method and recurse until RecursionError.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join("{}: {}".format(name, str(meter)) for name, meter in self.meters.items())

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def dump_in_output_file(self, iteration, iter_time, data_time):
        """Append one JSON line with the timings and every meter's median.

        Does nothing when no output file is configured or when this is not
        the main process.
        """
        if self.output_file is None or not distributed.is_main_process():
            return
        dict_to_dump = dict(
            iteration=iteration,
            iter_time=iter_time,
            data_time=data_time,
        )
        dict_to_dump.update({k: v.median for k, v in self.meters.items()})
        with open(self.output_file, "a") as f:
            f.write(json.dumps(dict_to_dump) + "\n")

    def log_every(self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0):
        """Yield the items of ``iterable``, logging progress along the way.

        Args:
            iterable: Source of items; must support ``len()`` when
                ``n_iterations`` is None.
            print_freq: Log every ``print_freq`` iterations (and on iteration
                ``n_iterations - 1``).
            header: Optional prefix of each log line.
            n_iterations: Iteration index at which to stop; defaults to
                ``len(iterable)``.
            start_iteration: Index assigned to the first yielded item.
        """
        i = start_iteration
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.6f}")
        data_time = SmoothedValue(fmt="{avg:.6f}")

        if n_iterations is None:
            n_iterations = len(iterable)

        space_fmt = ":" + str(len(str(n_iterations))) + "d"

        log_list = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            log_list += ["max mem: {memory:.0f}"]

        log_msg = self.delimiter.join(log_list)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == n_iterations - 1:
                self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg)
                eta_seconds = iter_time.global_avg * (n_iterations - i)
                fields = dict(
                    eta=str(datetime.timedelta(seconds=int(eta_seconds))),
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                if use_cuda:
                    fields["memory"] = torch.cuda.max_memory_allocated() / MB
                logger.info(log_msg.format(i, n_iterations, **fields))
            i += 1
            end = time.time()
            if i >= n_iterations:
                break
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # Guard the per-iteration average against n_iterations == 0.
        per_iteration = total_time / max(n_iterations, 1)
        logger.info("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, per_iteration))
131
+
132
+
133
class SmoothedValue:
    """Keep a sliding window of recent values plus running totals.

    Window statistics (median, avg, max, value) are computed over the last
    ``window_size`` values; ``global_avg`` is computed over everything seen.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, num=1):
        """Record ``value`` once in the window and ``num`` times in the totals."""
        self.deque.append(value)
        self.count += num
        self.total += value * num

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes when distributed mode is enabled.
        Warning: does not synchronize the deque!
        """
        if distributed.is_enabled():
            stats = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
            torch.distributed.barrier()
            torch.distributed.all_reduce(stats)
            count, total = stats.tolist()
            self.count = int(count)
            self.total = total

    @property
    def median(self):
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)
torch_hub/facebookresearch_dinov2_main/dinov2/loss/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_clstoken_loss import DINOLoss
7
+ from .ibot_patch_loss import iBOTPatchLoss
8
+ from .koleo_loss import KoLeoLoss
torch_hub/facebookresearch_dinov2_main/dinov2/loss/ibot_patch_loss.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ import logging
12
+
13
+
14
+ logger = logging.getLogger("dinov2")
15
+
16
+
17
try:
    from xformers.ops import cross_entropy

    def lossfunc(t, s, temp):
        """Teacher/student loss term computed with the xformers ``cross_entropy`` op.

        Both inputs are converted to float. 2-D inputs are given a temporary
        leading dimension for the call; inputs that are neither 2-D nor 3-D
        yield ``None``.
        """
        student = s.float()
        teacher = t.float()
        rank = student.ndim
        if rank == 3:
            return -cross_entropy(student, teacher, temp, bw_inplace=True)
        if rank == 2:
            batched = cross_entropy(student.unsqueeze(0), teacher.unsqueeze(0), temp, bw_inplace=True)
            return -batched.squeeze(0)
        return None

except ImportError:

    def lossfunc(t, s, temp):
        """Pure PyTorch fallback: sum over the last dim of ``t * log_softmax(s / temp)``."""
        log_probs = F.log_softmax(s / temp, dim=-1)
        return (t * log_probs).sum(dim=-1)
32
+
33
+
34
class iBOTPatchLoss(nn.Module):
    """Patch-level loss between teacher and student patch tokens.

    The module keeps a ``center`` buffer that is updated with momentum from
    teacher outputs, and offers two ways of turning teacher outputs into
    targets: centering followed by a softmax, or Sinkhorn-Knopp normalization.
    """

    def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9):
        """
        Args:
            patch_out_dim: size of the last dimension of the patch outputs
                (and of the ``center`` buffer).
            student_temp: temperature applied to the student outputs.
            center_momentum: momentum used when updating ``center``.
        """
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, 1, patch_out_dim))
        # State of the deferred center update (see reduce_center_update /
        # apply_center_update): `updated` is False while an update is pending.
        self.updated = True
        self.reduce_handle = None
        self.len_teacher_patch_tokens = None
        self.async_batch_center = None

    @torch.no_grad()
    def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp):
        """Center the teacher outputs, sharpen them and apply a softmax.

        Any pending center update is applied first.
        """
        self.apply_center_update()
        # teacher centering and sharpening
        #
        # WARNING:
        #   as self.center is a float32, everything gets casted to float32 afterwards
        return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1)

    @torch.no_grad()
    def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3):
        """Turn teacher outputs into assignments with Sinkhorn-Knopp iterations.

        Args:
            teacher_output: 2-D tensor of teacher outputs, one row per sample.
            teacher_temp: temperature dividing the outputs before ``exp``.
            n_masked_patches_tensor: tensor holding the local number of samples.
                When a process group is initialized it is all-reduced IN PLACE,
                so the caller's tensor holds the global count afterwards.
            n_iterations: number of row/column normalization rounds.

        Returns:
            Tensor with the same layout as ``teacher_output`` whose rows sum to 1.
        """
        teacher_output = teacher_output.float()
        Q = torch.exp(teacher_output / teacher_temp).t()  # Q is K-by-B for consistency with notations from our paper
        B = n_masked_patches_tensor  # number of samples to assign
        # Guarded like every other collective in this class, so the method
        # also works when no process group has been initialized.
        if dist.is_initialized():
            dist.all_reduce(B)
        K = Q.shape[0]  # how many prototypes

        # make the matrix sums to 1
        sum_Q = torch.sum(Q)
        if dist.is_initialized():
            dist.all_reduce(sum_Q)
        Q /= sum_Q

        for _ in range(n_iterations):
            # normalize each row: total weight per prototype must be 1/K
            sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
            if dist.is_initialized():
                dist.all_reduce(sum_of_rows)
            Q /= sum_of_rows
            Q /= K

            # normalize each column: total weight per sample must be 1/B
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= B

        Q *= B  # the columns must sum to 1 so that Q is an assignment
        return Q.t()

    def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        student_patch_tokens: (B, N, D) tensor
        teacher_patch_tokens: (B, N, D) tensor
        student_masks_flat: (B, N) tensor

        The per-token loss is averaged over the masked tokens of each sample
        (at least 1 in the denominator), then averaged over samples.
        """
        t = teacher_patch_tokens
        s = student_patch_tokens
        loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
        loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0)
        return -loss.mean()

    def forward_masked(
        self,
        student_patch_tokens_masked,
        teacher_patch_tokens_masked,
        student_masks_flat,
        n_masked_patches=None,
        masks_weight=None,
    ):
        """Loss computed on already-gathered masked tokens.

        Args:
            student_patch_tokens_masked: student outputs for the masked tokens.
            teacher_patch_tokens_masked: teacher targets for the same tokens.
            student_masks_flat: boolean mask used to derive the default weights
                and the normalizing sample count (its first dimension).
            n_masked_patches: when given, only the first ``n_masked_patches``
                loss entries are kept.
            masks_weight: per-token weights; by default each masked token gets
                ``1 / (number of masked tokens in its sample)``.
        """
        t = teacher_patch_tokens_masked
        s = student_patch_tokens_masked
        loss = lossfunc(t, s, self.student_temp)
        if masks_weight is None:
            masks_weight = (
                (1 / student_masks_flat.sum(-1).clamp(min=1.0))
                .unsqueeze(-1)
                .expand_as(student_masks_flat)[student_masks_flat]
            )
        if n_masked_patches is not None:
            loss = loss[:n_masked_patches]
        loss = loss * masks_weight
        return -loss.sum() / student_masks_flat.shape[0]

    @torch.no_grad()
    def update_center(self, teacher_patch_tokens):
        """Schedule a center update from ``teacher_patch_tokens``."""
        self.reduce_center_update(teacher_patch_tokens)

    @torch.no_grad()
    def reduce_center_update(self, teacher_patch_tokens):
        """Compute the local batch center and start its asynchronous all-reduce."""
        self.updated = False
        self.len_teacher_patch_tokens = len(teacher_patch_tokens)
        self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True)
        if dist.is_initialized():
            self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)

    @torch.no_grad()
    def apply_center_update(self):
        """Fold the pending batch center into ``center``; no-op when nothing is pending."""
        if self.updated is False:
            world_size = dist.get_world_size() if dist.is_initialized() else 1

            # Wait for the asynchronous reduction started in reduce_center_update.
            if self.reduce_handle is not None:
                self.reduce_handle.wait()
            _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size)

            self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)

            self.updated = True
torch_hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.utils.checkpoint
19
+ from torch.nn.init import trunc_normal_
20
+
21
+ from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
22
+
23
+
24
+ logger = logging.getLogger("dinov2")
25
+
26
+
27
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Call ``fn(module=..., name=...)`` on the descendants of ``module``.

    Args:
        fn: callable invoked with keyword arguments ``module`` and ``name``.
        module: root of the traversal.
        name: dotted prefix used to build the names of the children.
        depth_first: visit a module after its children when true, before
            them otherwise.
        include_root: whether ``fn`` is also called on ``module`` itself
            (children are always visited).

    Returns:
        ``module``, unchanged.
    """
    visit_before = not depth_first and include_root
    visit_after = depth_first and include_root

    if visit_before:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        qualified_name = f"{name}.{child_name}" if name else child_name
        named_apply(fn=fn, module=child_module, name=qualified_name, depth_first=depth_first, include_root=True)
    if visit_after:
        fn(module=module, name=name)
    return module
36
+
37
+
38
class BlockChunk(nn.ModuleList):
    """A ``ModuleList`` that can be called: applies its modules one after another."""

    def forward(self, x):
        """Feed ``x`` through every contained module in order and return the result."""
        out = x
        for layer in self:
            out = layer(out)
        return out
43
+
44
+
45
class DinoVisionTransformer(nn.Module):
    """Vision Transformer with a class token, optional register tokens and a mask token.

    The input is patch-embedded, a class token is prepended, (interpolated)
    positional embeddings are added, register tokens are inserted right after
    the class token, and the sequence is run through the transformer blocks
    followed by a final LayerNorm.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        channel_adaptive=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
            channel_adaptive: (bool) stored as ``bag_of_channels``; when True, get_intermediate_layers
                moves the channels of its input into the batch dimension

        Raises:
            NotImplementedError: if ``ffn_layer`` is not one of the supported names.
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset
        self.bag_of_channels = channel_adaptive

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One positional embedding per patch plus one for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = np.linspace(0, drop_path_rate, depth).tolist()  # stochastic depth decay rule

        # Resolve the FFN name into the class (or factory) handed to every block.
        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            # NOTE(review): a block_chunks value larger than depth makes chunksize 0,
            # which range() rejects as a step.
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initialize positional embeddings, class/register tokens and all Linear layers."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Return positional embeddings matching the token sequence ``x``.

        ``x`` is the token sequence including the class token; ``w`` and ``h``
        are the last two sizes of the input image tensor. When the number of
        patch tokens equals the stored one and ``w == h``, the stored
        embeddings are returned as-is. Otherwise the patch embeddings are
        bicubically resized to a ``(w // patch_size, h // patch_size)`` grid
        and the class-token embedding is re-attached in front.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        # Interpolate in float32; the result is cast back to the input dtype at the end.
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        # NOTE(review): the floor divisions assume patch_size is an int here.
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Turn a 4-D input into the token sequence fed to the blocks.

        Steps: patch embedding, replacement of the positions selected by
        ``masks`` with ``mask_token``, class token prepended, positional
        embeddings added, register tokens inserted right after the class token.
        """
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            # Registers go between the class token and the patch tokens; they
            # receive no positional embedding (added above, before insertion).
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """List variant of forward_features: one output dict per (input, masks) pair.

        The whole list of token sequences is passed to each block at once.
        """
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Run the backbone and return a dict of normalized and pre-norm tokens.

        Keys: ``x_norm_clstoken``, ``x_norm_regtokens``, ``x_norm_patchtokens``,
        ``x_prenorm`` and ``masks``. A list input is dispatched to
        forward_features_list and yields a list of such dicts.
        """
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        """Collect block outputs when ``self.blocks`` is a flat list of blocks."""
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        """Collect block outputs when ``self.blocks`` is a list of BlockChunk.

        The last chunk holds one entry per block index (leading entries are
        Identity placeholders), so its length gives the total block count.
        """
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch tokens (and optionally class tokens) of selected blocks.

        Args:
            x: 4-D input tensor.
            n: number of last blocks to take, or explicit block indices.
            reshape: reshape each patch-token output to
                ``(B, -1, w // patch_size, h // patch_size)``.
            return_class_token: return ``(patch_tokens, class_token)`` pairs.
            norm: apply the final norm to each selected output.

        Returns:
            A tuple with one entry per selected block. With ``bag_of_channels``
            set, every entry is a ``(patch_tokens, class_tokens)`` pair
            regrouped per image, regardless of ``return_class_token``.
        """

        if self.bag_of_channels:
            B, C, H, W = x.shape
            x = x.reshape(B * C, 1, H, W)  # passing channels to batch dimension to get encodings for each channel

        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # Drop the class token and the register tokens, keeping patch tokens only.
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            # NOTE(review): this rebinds B; combined with bag_of_channels the
            # per-image reshape below then sees B * C instead of B — confirm
            # whether that combination is meant to be supported.
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]

        if self.bag_of_channels:
            output = tuple(zip(outputs, class_tokens))
            output = list(
                zip(*output)
            )  # unzip the tuple: (list of patch_tokens per block, list of class tokens per block)
            patch_tokens_per_block = output[0]  # [BLOCK1, BLOCK2, ...] where BLOCK1.shape: B*C, N, D
            cls_tokens_per_block = output[1]  # [BLOCK1, BLOCK2, ...] where BLOCK1.shape: B*C, D
            patch_tokens_per_block = [
                patch_tokens.reshape(B, C, patch_tokens.shape[-2], patch_tokens.shape[-1])
                for patch_tokens in patch_tokens_per_block
            ]  # [BLOCK1, BLOCK2, ...] where BLOCK1.shape: B, C, N, D
            cls_tokens_per_block = [cls_tokens.reshape(B, -1) for cls_tokens in cls_tokens_per_block]
            output = tuple(zip(patch_tokens_per_block, cls_tokens_per_block))
            return output

        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Return the full feature dict when ``is_training``, else ``head`` applied to the class token."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
354
+
355
+
356
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)

    Only ``nn.Linear`` modules are touched: the weight gets a truncated
    normal init (std 0.02) and the bias, when present, is zeroed. ``name``
    is accepted for compatibility with ``named_apply`` and is not used.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)
362
+
363
+
364
def vit_small(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs):
    """Build a DinoVisionTransformer with embed-dim 384, depth 12 and 6 heads.

    Blocks use ``MemEffAttention``; extra keyword arguments are forwarded to
    the ``DinoVisionTransformer`` constructor.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=attention_block,
        num_register_tokens=num_register_tokens,
        channel_adaptive=channel_adaptive,
        **kwargs,
    )
378
+
379
+
380
def vit_base(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs):
    """Build a DinoVisionTransformer with embed-dim 768, depth 12 and 12 heads.

    Blocks use ``MemEffAttention``; extra keyword arguments are forwarded to
    the ``DinoVisionTransformer`` constructor.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=attention_block,
        num_register_tokens=num_register_tokens,
        channel_adaptive=channel_adaptive,
        **kwargs,
    )
394
+
395
+
396
def vit_large(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs):
    """Build a DinoVisionTransformer with embed-dim 1024, depth 24 and 16 heads.

    Blocks use ``MemEffAttention``; extra keyword arguments are forwarded to
    the ``DinoVisionTransformer`` constructor.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=attention_block,
        num_register_tokens=num_register_tokens,
        channel_adaptive=channel_adaptive,
        **kwargs,
    )
410
+
411
+
412
def vit_giant2(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64

    Depth is 40 and blocks use ``MemEffAttention``; extra keyword arguments
    are forwarded to the ``DinoVisionTransformer`` constructor.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=attention_block,
        num_register_tokens=num_register_tokens,
        channel_adaptive=channel_adaptive,
        **kwargs,
    )
torch_hub/facebookresearch_dinov2_main/dinov2/run/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
torch_hub/facebookresearch_dinov2_main/dinov2/run/eval/cell_dino/knn.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the CC-by-NC licence,
4
+ # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import sys
9
+
10
+ from dinov2.eval.cell_dino.knn import get_args_parser as get_knn_args_parser
11
+ from dinov2.logging import setup_logging
12
+ from dinov2.run.submit import get_args_parser, submit_jobs
13
+
14
+
15
+ logger = logging.getLogger("dinov2")
16
+
17
+
18
class Evaluator:
    """Callable task that runs the Cell-DINO k-NN evaluation inside a submitit job."""

    def __init__(self, args):
        self.args = args

    def __call__(self):
        """Finalize the arguments for the running job, then run the k-NN evaluation."""
        # Imported lazily, at call time rather than at module import.
        from dinov2.eval.cell_dino.knn import main as knn_main

        self._setup_args()
        knn_main(self.args)

    def checkpoint(self):
        """Return a delayed submission of a fresh task built from the same arguments."""
        import submitit

        logger.info(f"Requeuing {self.args}")
        requeued_task = type(self)(self.args)
        return submitit.helpers.DelayedSubmission(requeued_task)

    def _setup_args(self):
        """Substitute the job id for ``%j`` in ``output_dir`` and log the job layout."""
        import submitit

        env = submitit.JobEnvironment()
        output_template = self.args.output_dir
        self.args.output_dir = output_template.replace("%j", str(env.job_id))
        logger.info(f"Process group: {env.num_tasks} tasks, rank: {env.global_rank}")
        logger.info(f"Args: {self.args}")
42
+
43
+
44
def main():
    """Parse the command line and submit the Cell-DINO k-NN evaluation job.

    Returns:
        0 once the job has been handed over to ``submit_jobs``.

    Raises:
        AssertionError: if ``args.config_file`` does not exist.
    """
    description = "Submitit launcher for k-NN Cell-DINO and Channel-Adaptive DINO evaluation"
    knn_args_parser = get_knn_args_parser(add_help=False)
    parents = [knn_args_parser]
    args_parser = get_args_parser(description=description, parents=parents)
    args = args_parser.parse_args()

    setup_logging()

    # Explicit raise instead of an `assert` statement so that the check is not
    # stripped under `python -O`; exception type and message are unchanged.
    if not os.path.exists(args.config_file):
        raise AssertionError("Configuration file does not exist!")
    submit_jobs(Evaluator, args, name="dinov2:knn Cell-DINO")
    return 0


if __name__ == "__main__":
    sys.exit(main())
torch_hub/facebookresearch_dinov2_main/dinov2/run/eval/cell_dino/linear.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the CC-by-NC licence,
4
+ # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import sys
9
+
10
+ from dinov2.eval.cell_dino.linear import get_args_parser as get_linear_args_parser
11
+ from dinov2.logging import setup_logging
12
+ from dinov2.run.submit import get_args_parser, submit_jobs
13
+
14
+
15
+ logger = logging.getLogger("dinov2")
16
+
17
+
18
class Evaluator:
    """Callable task that runs the Cell-DINO linear evaluation inside a submitit job."""

    def __init__(self, args):
        self.args = args

    def __call__(self):
        """Finalize the arguments for the running job, then run the linear evaluation."""
        # Imported lazily, at call time rather than at module import.
        from dinov2.eval.cell_dino.linear import main as linear_main

        self._setup_args()
        linear_main(self.args)

    def checkpoint(self):
        """Return a delayed submission of a fresh task built from the same arguments."""
        import submitit

        logger.info(f"Requeuing {self.args}")
        requeued_task = type(self)(self.args)
        return submitit.helpers.DelayedSubmission(requeued_task)

    def _setup_args(self):
        """Substitute the job id for ``%j`` in ``output_dir`` and log the job layout."""
        import submitit

        env = submitit.JobEnvironment()
        output_template = self.args.output_dir
        self.args.output_dir = output_template.replace("%j", str(env.job_id))
        logger.info(f"Process group: {env.num_tasks} tasks, rank: {env.global_rank}")
        logger.info(f"Args: {self.args}")
42
+
43
+
44
def main():
    """Parse the command line and submit the Cell-DINO linear evaluation job.

    Returns:
        0 once the job has been handed over to ``submit_jobs``.

    Raises:
        AssertionError: if ``args.config_file`` does not exist.
    """
    description = "Submitit launcher for DINOv2 linear Cell-DINO and Channel-Adaptive DINO evaluation"
    linear_args_parser = get_linear_args_parser(add_help=False)
    parents = [linear_args_parser]
    args_parser = get_args_parser(description=description, parents=parents)
    args = args_parser.parse_args()

    setup_logging()

    # Explicit raise instead of an `assert` statement so that the check is not
    # stripped under `python -O`; exception type and message are unchanged.
    if not os.path.exists(args.config_file):
        raise AssertionError("Configuration file does not exist!")
    submit_jobs(Evaluator, args, name="dinov2:linear Cell-DINO")
    return 0


if __name__ == "__main__":
    sys.exit(main())
torch_hub/facebookresearch_dinov2_main/dinov2/run/submit.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import logging
8
+ import os
9
+ from pathlib import Path
10
+ from typing import List, Optional
11
+
12
+ import submitit
13
+
14
+ from dinov2.utils.cluster import (
15
+ get_slurm_executor_parameters,
16
+ get_slurm_partition,
17
+ get_user_checkpoint_path,
18
+ )
19
+
20
+
21
+ logger = logging.getLogger("dinov2")
22
+
23
+
24
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
) -> argparse.ArgumentParser:
    """Build the parser holding the job-submission options.

    Args:
        description: text shown in the parser's help.
        parents: parsers whose arguments are inherited; ``None`` means none.
        add_help: whether the parser gets its own ``-h/--help`` option.

    Returns:
        A parser with the options ``--ngpus``, ``--nodes``, ``--timeout``,
        ``--partition``, ``--use-volta32``, ``--comment`` and ``--exclude``.
    """
    parents = parents or []
    default_partition = get_slurm_partition()

    # (flags, add_argument keyword arguments), in the order they are registered.
    option_specs = (
        (
            ("--ngpus", "--gpus", "--gpus-per-node"),
            {"default": 8, "type": int, "help": "Number of GPUs to request on each node"},
        ),
        (
            ("--nodes", "--nnodes"),
            {"default": 1, "type": int, "help": "Number of nodes to request"},
        ),
        (
            ("--timeout",),
            {"default": 2800, "type": int, "help": "Duration of the job"},
        ),
        (
            ("--partition",),
            {"default": default_partition, "type": str, "help": "Partition where to submit"},
        ),
        (
            ("--use-volta32",),
            {"action": "store_true", "help": "Request V100-32GB GPUs"},
        ),
        (
            ("--comment",),
            {"default": "", "type": str, "help": "Comment to pass to scheduler, e.g. priority message"},
        ),
        (
            ("--exclude",),
            {"default": "", "type": str, "help": "Nodes to exclude"},
        ),
    )

    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser
81
+
82
+
83
def get_shared_folder() -> Path:
    """Return the ``experiments`` folder under the user checkpoint path, creating it if needed.

    Raises:
        RuntimeError: if ``get_user_checkpoint_path()`` returns ``None``.
    """
    checkpoint_root = get_user_checkpoint_path()
    if checkpoint_root is None:
        raise RuntimeError("Path to user checkpoint cannot be determined")
    experiments_dir = checkpoint_root / "experiments"
    # Only the last path component is created; the checkpoint root must already exist.
    experiments_dir.mkdir(exist_ok=True)
    return experiments_dir
90
+
91
+
92
def submit_jobs(task_class, args, name: str):
    """Submit ``task_class(args)`` as a job through a submitit ``AutoExecutor``.

    Args:
        task_class: class instantiated with ``args`` to build the submitted task.
        args: parsed command-line namespace; ``output_dir`` is filled in with a
            ``%j``-templated shared folder when empty.
        name: job name passed to the executor.
    """
    if not args.output_dir:
        args.output_dir = str(get_shared_folder() / "%j")

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)

    # Optional scheduler settings: only the ones actually requested are forwarded.
    optional_settings = (
        ("slurm_constraint", "volta32gb" if args.use_volta32 else ""),
        ("slurm_comment", args.comment),
        ("slurm_exclude", args.exclude),
    )
    extra_params = {key: setting for key, setting in optional_settings if setting}

    executor_params = get_slurm_executor_parameters(
        nodes=args.nodes,
        num_gpus_per_node=args.ngpus,
        timeout_min=args.timeout,  # max is 60 * 72
        slurm_signal_delay_s=120,
        slurm_partition=args.partition,
        **extra_params,
    )
    executor.update_parameters(name=name, **executor_params)

    job = executor.submit(task_class(args))

    logger.info(f"Submitted job_id: {job.job_id}")
    resolved_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id))
    logger.info(f"Logs and checkpoints will be saved at: {resolved_output_dir}")
torch_hub/facebookresearch_dinov2_main/dinov2/run/train/train.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import sys
9
+
10
+ from dinov2.logging import setup_logging
11
+ from dinov2.train import get_args_parser as get_train_args_parser
12
+ from dinov2.run.submit import get_args_parser, submit_jobs
13
+
14
+
15
+ logger = logging.getLogger("dinov2")
16
+
17
+
18
class Trainer(object):
    """Callable task that runs DINOv2 training inside a submitit job."""

    def __init__(self, args):
        self.args = args

    def __call__(self):
        """Finalize the arguments for the running job, then run training."""
        # Imported lazily, at call time rather than at module import.
        from dinov2.train import main as train_main

        self._setup_args()
        train_main(self.args)

    def checkpoint(self):
        """Return a delayed submission of a fresh task built from the same arguments."""
        import submitit

        logger.info(f"Requeuing {self.args}")
        requeued_task = type(self)(self.args)
        return submitit.helpers.DelayedSubmission(requeued_task)

    def _setup_args(self):
        """Substitute the job id for ``%j`` in ``output_dir`` and log the job layout."""
        import submitit

        env = submitit.JobEnvironment()
        output_template = self.args.output_dir
        self.args.output_dir = output_template.replace("%j", str(env.job_id))
        logger.info(f"Process group: {env.num_tasks} tasks, rank: {env.global_rank}")
        logger.info(f"Args: {self.args}")
42
+
43
+
44
def main():
    """Parse the command line and submit the DINOv2 training job.

    Returns:
        0 once the job has been handed over to ``submit_jobs``.

    Raises:
        AssertionError: if ``args.config_file`` does not exist.
    """
    description = "Submitit launcher for DINOv2 training"
    train_args_parser = get_train_args_parser(add_help=False)
    parents = [train_args_parser]
    args_parser = get_args_parser(description=description, parents=parents)
    args = args_parser.parse_args()

    setup_logging()

    # Explicit raise instead of an `assert` statement so that the check is not
    # stripped under `python -O`; exception type and message are unchanged.
    if not os.path.exists(args.config_file):
        raise AssertionError("Configuration file does not exist!")
    submit_jobs(Trainer, args, name="dinov2:train")
    return 0


if __name__ == "__main__":
    sys.exit(main())
torch_hub/facebookresearch_dinov2_main/dinov2/thirdparty/CLIP/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 OpenAI
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
torch_hub/facebookresearch_dinov2_main/dinov2/thirdparty/CLIP/clip/simple_tokenizer.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
@lru_cache()
def default_bpe():
    """Return the path of the gzipped BPE vocabulary file located next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
@lru_cache()
def bytes_to_unicode():
    """
    Build the lookup table from byte values (0-255) to single unicode characters.

    The reversible BPE codes operate on unicode strings, so every byte needs a
    character of its own; otherwise a large share of the vocabulary would be
    spent on raw unicode characters just to avoid UNKs. Bytes in the three
    printable ranges below map to the character with the same code point. Every
    other byte is mapped, in increasing byte order, to chr(256), chr(257), ...
    which keeps the table free of the whitespace/control characters the BPE
    code cannot handle.

    The dictionary is ordered: printable bytes first, then the remaining bytes.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    already_mapped = set(printable)
    remaining = [b for b in range(2**8) if b not in already_mapped]

    table = {b: chr(b) for b in printable}
    for offset, b in enumerate(remaining):
        # Shift past the 8-bit range so these never collide with printable entries.
        table[b] = chr(2**8 + offset)
    return table
36
+
37
+
38
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Args:
        word: sequence of symbols (variable-length strings), e.g. a tuple.

    Returns:
        A set of ``(left, right)`` tuples, one per pair of neighbouring symbols.
        The set is empty when the word has fewer than two symbols.
    """
    # zip() stops at the shorter argument, so empty and single-symbol words
    # simply produce no pairs instead of failing on word[0].
    return set(zip(word, word[1:]))
48
+
49
+
50
def basic_clean(text):
    """Repair the text with ftfy, undo two levels of HTML escaping, and strip both ends."""
    repaired = ftfy.fix_text(text)
    unescaped_once = html.unescape(repaired)
    unescaped_twice = html.unescape(unescaped_once)
    return unescaped_twice.strip()
54
+
55
+
56
def whitespace_clean(text):
    """Collapse every run of whitespace into a single space and trim both ends."""
    return re.sub(r"\s+", " ", text).strip()
60
+
61
+
62
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer driven by a gzipped merges file."""

    def __init__(self, bpe_path: str = default_bpe()):
        """Load the merges from `bpe_path` and build the encoder/decoder tables.

        Args:
            bpe_path: path to the gzipped merges file; defaults to the file
                returned by `default_bpe()` (evaluated once, at class definition).
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # The context manager closes the gzip handle once the content is read.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split("\n")
        # Skip line 0 (presumably a header -- confirm against the vocabulary file)
        # and keep at most 49152 - 256 - 2 merges.
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocabulary layout: byte symbols, byte symbols with the end-of-word
        # marker, one entry per merge, then the two special tokens.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Lower rank == merge listed earlier in the file == applied first.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Special tokens are pre-seeded so bpe() returns them unsplit.
        self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )

    def bpe(self, token):
        """Apply the learned merges to one token.

        Args:
            token: non-empty string of byte-encoded characters.

        Returns:
            The resulting BPE symbols joined by single spaces; the last symbol
            carries the ``</w>`` end-of-word marker. Results are memoized in
            `self.cache`, except for single-character tokens, which return early.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            # Pick the adjacent pair with the best (lowest) merge rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the tail and stop scanning.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Convert text to a list of vocabulary ids.

        The text is cleaned, whitespace-normalized and lower-cased, split with
        `self.pat`, and each piece is mapped byte by byte through
        `self.byte_encoder` before BPE is applied.

        Raises:
            KeyError: if a BPE symbol is missing from `self.encoder`.
        """
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        """Convert vocabulary ids back to text.

        Invalid UTF-8 byte sequences are replaced rather than raising, and each
        ``</w>`` marker becomes a space.
        """
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("</w>", " ")
        return text
torch_hub/facebookresearch_dinov2_main/dinov2/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
torch_hub/facebookresearch_dinov2_main/dinov2/utils/checkpoint.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the CC-by-NC licence,
4
+ # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
+
6
+ from typing import Any
7
+
8
+ from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
9
+ from torch import nn
10
+
11
+ import dinov2.distributed as dist
12
+
13
+
14
class PeriodicCheckpointerWithCleanup(PeriodicCheckpointer):
    """`PeriodicCheckpointer` whose `step` is a no-op when the wrapped checkpointer does not write."""

    @property
    def does_write(self) -> bool:
        """See https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py#L114"""
        # bool() makes the property honour its annotation: the bare `and`
        # returned `save_dir` itself (e.g. "") whenever that operand was falsy.
        return bool(self.checkpointer.save_dir and self.checkpointer.save_to_disk)

    def save_best(self, **kwargs: Any) -> None:
        """Same argument as `Checkpointer.save`, to save a model named like `model_best.pth`"""
        self.checkpointer.save(f"{self.file_prefix}_best", **kwargs)

    def has_checkpoint(self) -> bool:
        """Forward to the wrapped checkpointer's `has_checkpoint`."""
        return self.checkpointer.has_checkpoint()

    def get_checkpoint_file(self) -> str:  # returns "" if the file does not exist
        """Forward to the wrapped checkpointer's `get_checkpoint_file`."""
        return self.checkpointer.get_checkpoint_file()

    def load(self, path: str, checkpointables=None) -> dict[str, Any]:
        """Forward to the wrapped checkpointer's `load` with the same arguments."""
        return self.checkpointer.load(path=path, checkpointables=checkpointables)

    def step(self, iteration: int, **kwargs: Any) -> None:
        """Run the periodic checkpointing step, but only on objects that write."""
        if not self.does_write:  # step also removes files, so should be deactivated when object does not write
            return
        super().step(iteration=iteration, **kwargs)
37
+
38
+
39
def resume_or_load(checkpointer: Checkpointer, path: str, *, resume: bool = True) -> dict[str, Any]:
    """
    Load a checkpoint, preferring the latest saved one when resuming.

    When `resume` is True and the checkpointer reports an existing checkpoint,
    that checkpoint file is loaded; in every other case `path` is loaded.
    Modelled on Checkpointer.resume_or_load in fvcore
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py#L208
    but the checkpointables are always reloaded, in case the training is resumed
    in a new job.

    Returns:
        Whatever `checkpointer.load` returns for the selected file.
    """
    use_latest = resume and checkpointer.has_checkpoint()
    target = checkpointer.get_checkpoint_file() if use_latest else path
    return checkpointer.load(target)
50
+
51
+
52
def build_periodic_checkpointer(
    model: nn.Module,
    save_dir="",
    *,
    period: int,
    max_iter=None,
    max_to_keep=None,
    **checkpointables: Any,
) -> PeriodicCheckpointerWithCleanup:
    """Create a `Checkpointer` for `model` and wrap it in a `PeriodicCheckpointerWithCleanup`.

    Args:
        model: module passed to the `Checkpointer`.
        save_dir: directory passed to the `Checkpointer`.
        period: passed to the periodic checkpointer.
        max_iter: passed to the periodic checkpointer.
        max_to_keep: passed to the periodic checkpointer.
        **checkpointables: extra objects passed to the `Checkpointer`.
    """
    # Only the main process is allowed to write to disk.
    writes_to_disk = dist.is_main_process()
    base_checkpointer = Checkpointer(model, save_dir, save_to_disk=writes_to_disk, **checkpointables)
    return PeriodicCheckpointerWithCleanup(
        base_checkpointer,
        period,
        max_iter=max_iter,
        max_to_keep=max_to_keep,
    )
torch_hub/facebookresearch_dinov2_main/dinov2/utils/cluster.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Optional
10
+
11
+
12
class ClusterType(Enum):
    """Cluster flavours for which checkpoint paths and SLURM settings are defined in this module."""

    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
16
+
17
+
18
def _guess_cluster_type() -> ClusterType:
    """Guess the cluster from `os.uname()`; anything unrecognized is treated as FAIR."""
    info = os.uname()
    if info.sysname != "Linux":
        return ClusterType.FAIR
    if info.release.endswith("-aws"):
        # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
        return ClusterType.AWS
    if info.nodename.startswith("rsc"):
        # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
        return ClusterType.RSC
    return ClusterType.FAIR
29
+
30
+
31
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    """Return `cluster_type` unchanged, or a guess based on the current host when it is None."""
    return _guess_cluster_type() if cluster_type is None else cluster_type
36
+
37
+
38
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the root checkpoint directory for the given (or guessed) cluster.

    Returns None when the cluster type resolves to None.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    # Directory, relative to the filesystem root, used on each cluster.
    dirname_by_cluster = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    dirname = dirname_by_cluster[resolved]
    return Path("/") / dirname
49
+
50
+
51
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the current user's sub-directory of the cluster checkpoint directory.

    The user name is read from the ``USER`` environment variable, which must be
    set. Returns None when no checkpoint directory is available.
    """
    base_dir = get_checkpoint_path(cluster_type)
    if base_dir is None:
        return None

    username = os.environ.get("USER")
    assert username is not None
    return base_dir / username
59
+
60
+
61
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the SLURM partition name used on the given (or guessed) cluster.

    Returns None when the cluster type resolves to None.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    partition_by_cluster = {
        ClusterType.AWS: "learnaccel",
        ClusterType.FAIR: "learnaccel",
        ClusterType.RSC: "learn",
    }
    return partition_by_cluster[resolved]
72
+
73
+
74
def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Build the SLURM executor parameters for a job.

    Args:
        nodes: number of nodes to request.
        num_gpus_per_node: GPUs per node; also used as the number of tasks per node.
        cluster_type: cluster to target; guessed from the host when None.
        **kwargs: extra parameters, applied last so they override the defaults.

    Returns:
        The parameter dictionary: defaults, then cluster-specific adjustments,
        then `kwargs`.
    """
    partition = get_slurm_partition(cluster_type)
    params: Dict[str, Any] = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": partition,
    }

    # Cluster-specific adjustments: AWS and RSC use 12 CPUs per task, and on AWS
    # the memory request is dropped altogether.
    resolved = get_cluster_type(cluster_type)
    if resolved in (ClusterType.AWS, ClusterType.RSC):
        params["cpus_per_task"] = 12
    if resolved == ClusterType.AWS:
        params.pop("mem_gb")

    # Caller-supplied additions and overrides win over everything above.
    params.update(kwargs)
    return params
torch_hub/facebookresearch_dinov2_main/dinov2/utils/config.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import logging
8
+ import os
9
+
10
+ from omegaconf import OmegaConf
11
+
12
+ import dinov2.distributed as distributed
13
+ from dinov2.logging import setup_logging
14
+ from dinov2.utils import utils
15
+ from dinov2.configs import dinov2_default_config
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
def apply_scaling_rules_to_cfg(cfg):  # to fix
    """Set `cfg.optim.lr` from `cfg.optim.base_lr` according to `cfg.optim.scaling_rule`.

    Only the "sqrt_wrt_1024" rule exists: the base learning rate is multiplied
    by sqrt(batch_size_per_gpu * global_size / 1024).

    Returns:
        The same `cfg` object, modified in place.

    Raises:
        NotImplementedError: for any other scaling rule.
    """
    if cfg.optim.scaling_rule != "sqrt_wrt_1024":
        raise NotImplementedError

    base_lr = cfg.optim.base_lr
    cfg.optim.lr = base_lr
    total_batch_size = cfg.train.batch_size_per_gpu * distributed.get_global_size()
    cfg.optim.lr *= math.sqrt(total_batch_size / 1024.0)
    logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    return cfg
30
+
31
+
32
def write_config(cfg, output_dir, name="config.yaml"):
    """Log the configuration as YAML and save it to `output_dir/name`.

    Returns:
        The path of the written file.
    """
    logger.info(OmegaConf.to_yaml(cfg))
    destination = os.path.join(output_dir, name)
    with open(destination, "w") as handle:
        OmegaConf.save(config=cfg, f=handle)
    return destination
38
+
39
+
40
def get_cfg_from_args(args):
    """Build the run configuration from the parsed arguments.

    `args.output_dir` is made absolute and appended to `args.opts` as a
    ``train.output_dir=...`` override (both changes are made on `args` itself).
    The result merges, in this order, the default config, the file named by
    `args.config_file`, and the overrides in `args.opts`.
    """
    args.output_dir = os.path.abspath(args.output_dir)
    args.opts += [f"train.output_dir={args.output_dir}"]

    defaults = OmegaConf.create(dinov2_default_config)
    from_file = OmegaConf.load(args.config_file)
    overrides = OmegaConf.from_cli(args.opts)
    return OmegaConf.merge(defaults, from_file, overrides)
47
+
48
+
49
def default_setup(args):
    """Enable distributed mode, configure logging, seed the RNGs and log the run context.

    The seed is `args.seed` (0 when absent) plus the global rank of the process.
    The module-level `logger` is rebound after logging has been configured.
    """
    global logger

    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    utils.fix_random_seeds(seed + rank)
    logger.info("git:\n  {}\n".format(utils.get_sha()))
    # One "key: value" line per argument, sorted by argument name.
    arg_lines = ["%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())]
    logger.info("\n".join(arg_lines))
61
+
62
+
63
def setup(args):
    """
    Build the configuration and prepare the run.

    In order: create the config from `args`, create the output directory, run
    the default setup, apply the learning-rate scaling rule, and write the final
    config into the output directory.

    Returns:
        The resulting configuration object.
    """
    config = get_cfg_from_args(args)
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(config)
    write_config(config, args.output_dir)
    return config
torch_hub/facebookresearch_dinov2_main/dinov2/utils/param_groups.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import defaultdict
7
+ import logging
8
+
9
+
10
+ logger = logging.getLogger("dinov2")
11
+
12
+
13
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
    """
    Compute the layer-wise learning-rate multiplier of one ViT parameter.

    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
        force_is_backbone (bool): treat the parameter as a backbone parameter
            even when its name does not start with "backbone".
        chunked_blocks (bool): block names have the form "blocks.<chunk>.<index>."
            instead of "blocks.<index>.".
    Returns:
        lr_decay_rate ** (num_layers + 1 - layer_id), where layer_id is 0 for
        the embedding/token parameters, block index + 1 for block parameters,
        and num_layers + 1 (i.e. a multiplier of 1) for everything else.
    """
    embedding_keys = ("pos_embed", "patch_embed", "mask_token", "cls_token", "register_tokens")

    def block_layer_id(marker, position):
        # Cut the name at the first occurrence of `marker`, then read the block
        # index from the given dot-separated position.
        return int(name[name.find(marker) :].split(".")[position]) + 1

    layer_id = num_layers + 1
    if name.startswith("backbone") or force_is_backbone:
        if any("." + key in name for key in embedding_keys):
            layer_id = 0
        elif force_is_backbone and any(key in name for key in embedding_keys):
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = block_layer_id(".blocks.", 2)
        elif chunked_blocks and "blocks." in name and "residual." not in name:
            layer_id = block_layer_id("blocks.", 2)
        elif "blocks." in name and "residual." not in name:
            layer_id = block_layer_id("blocks.", 1)

    return lr_decay_rate ** (num_layers + 1 - layer_id)
49
+
50
+
51
def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
    """Create one parameter-group dict per trainable parameter of `model`.

    Args:
        model: module exposing `named_parameters()`; the number of blocks is
            read from `n_blocks`, `blocks` or `backbone.blocks`, whichever is
            found first, and is 0 when none of them exists.
        lr_decay_rate: base layer-wise lr decay rate.
        patch_embed_lr_mult: extra lr factor for parameters whose name
            contains "patch_embed".

    Returns:
        A list of dicts with the keys "params", "is_last_layer",
        "lr_multiplier", "wd_multiplier" and "name".
    """
    chunked_blocks = False
    if hasattr(model, "n_blocks"):
        logger.info("chunked fsdp")
        n_blocks = model.n_blocks
        chunked_blocks = model.chunked_blocks
    elif hasattr(model, "blocks"):
        logger.info("first code branch")
        n_blocks = len(model.blocks)
    elif hasattr(model, "backbone"):
        logger.info("second code branch")
        n_blocks = len(model.backbone.blocks)
    else:
        logger.info("else code branch")
        n_blocks = 0

    all_param_groups = []
    for raw_name, param in model.named_parameters():
        # Drop the FSDP wrapper prefix so the name checks below see plain names.
        name = raw_name.replace("_fsdp_wrapped_module.", "")
        if not param.requires_grad:
            continue

        decay_rate = get_vit_lr_decay_rate(
            name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
        )
        # Biases, norm parameters and "gamma" parameters get no weight decay.
        skips_weight_decay = name.endswith(".bias") or "norm" in name or "gamma" in name
        group = {
            "params": param,
            "is_last_layer": "last_layer" in name,
            "lr_multiplier": decay_rate,
            "wd_multiplier": 0.0 if skips_weight_decay else 1.0,
            "name": name,
        }
        if "patch_embed" in name:
            group["lr_multiplier"] = group["lr_multiplier"] * patch_embed_lr_mult

        all_param_groups.append(group)
        logger.info(f"""{name}: lr_multiplier: {group["lr_multiplier"]}, wd_multiplier: {group["wd_multiplier"]}""")

    return all_param_groups
90
+
91
+
92
def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
    """Merge parameter groups that share the same values for `keys`.

    Args:
        all_params_groups: iterable of dicts holding "params" and every key in `keys`.
        keys: names of the entries that must match for two groups to be fused.

    Returns:
        A view over the fused groups, in order of first appearance. Each fused
        group has the `keys` entries plus "params", the list of the "params"
        values of the groups that were merged into it.
    """
    fused = defaultdict(lambda: {"params": []})
    for group in all_params_groups:
        # Groups with identical values for all keys produce the same identifier.
        identifier = "".join(k + str(group[k]) + "_" for k in keys)
        bucket = fused[identifier]
        bucket.update((k, group[k]) for k in keys)
        bucket["params"].append(group["params"])
    return fused.values()
torch_hub/facebookresearch_dinov2_main/docs/README_CHANNEL_ADAPTIVE_DINO.md ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Scaling Channel-Adaptive Self-Supervised Learning
3
+
4
+ [[`Paper`](https://openreview.net/forum?id=pT8sgtRVAf)] [[`BibTeX`](#citing-channeladaptivedino-and-dinov2)]
5
+
6
+ **[Meta AI Research, FAIR](https://ai.facebook.com/research/)**
7
+
8
+ Alice V. De Lorenci, Seungeun Yi, Théo Moutakanni, Piotr Bojanowski, Camille Couprie, Juan C. Caicedo, Wolfgang M. Pernice,
9
+
10
+ with special thanks to Elouan Gardes for his contributions to the codebase.
11
+
12
+ PyTorch implementation and pretrained model for ChannelAdaptive-DINO.
13
+
14
+ The contents of this repo, including the code and model weights, are intended for research use only. It is not for use in medical procedures, including any diagnostics, treatment, or curative applications. Do not use this model for any clinical purpose or as a substitute for professional medical judgement.
15
+
16
+ ![teaser](ChannelAdaptiveDINO.png)
17
+
18
+ ## Pretrained model
19
+
20
+ You can download the model weights trained on the Extended CHAMMI dataset (combination of five cell microscopy datasets with variable numbers of channels) on torchhub.
21
+
22
+ ## Installation
23
+
24
+ Follow the installation instructions in the DINOv2 README. There are two additional dependencies: pandas and tifffile.
25
+
26
+ ## What is included / not included
27
+
28
+ This repository includes the Bag of Channel implementation, not the Hierarchical attention approach.
29
+
30
+ ## Data preparation
31
+
32
+ The CHAMMI dataset is available [here](https://github.com/chaudatascience/channel_adaptive_models).
33
+
34
+ The HPA-FoV dataset is available [here](https://www.ebi.ac.uk/biostudies/bioimages/studies/S-BIAD2443)
35
+
36
+ Content: a directory new_512_whole_images and two csv files:
37
+
38
+ "whole_images_512_test.csv"
39
+
40
+ "whole_images_512_train.csv"
41
+
42
+
43
+ :warning: To execute the commands provided in the next sections for training and evaluation, the `dinov2` package should be included in the Python module search path, i.e. simply prefix the command to run with `PYTHONPATH=.`.
44
+
45
+ ## Training
46
+
47
+ ### Fast setup: training Channel-Adaptive DINO ViT-L/16 on HPA single cell dataset
48
+
49
+ Run Channel-Adaptive DINO training on 4 A100-80GB nodes (32 GPUs) in a SLURM cluster environment with submitit:
50
+
51
+ ```shell
52
+ python dinov2/run/train/train.py \
53
+ --nodes 4 \
54
+ --config-file dinov2/configs/train/cell_dino/vitl16_boc_hpafov.yaml \
55
+ --output-dir <PATH/TO/OUTPUT/DIR> \
56
+ train.dataset_path=HPAFoV:split=TRAIN:root=<PATH/TO/DATASET>:wildcard=SEPARATE_CHANNELS
57
+ ```
58
+
59
+ Training time is approximately 2 days.
60
+ The training code saves the weights of the teacher in the `eval` folder every 12500 iterations for evaluation.
61
+ This example only performs pretraining on the HPA-FoV dataset.
62
+
63
+ ## Evaluation
64
+
65
+ The training code regularly saves the teacher weights. In order to evaluate the model, run the following evaluation on a single node:
66
+
67
+ ### Linear Evaluation on HPAFoV
68
+
69
+ ```shell
70
+ PYTHONPATH=.:dinov2/data python dinov2/run/eval/cell_dino/linear.py \
71
+ --config-file dinov2/configs/eval/cell_dino/vitl16_channel_adaptive_pretrain.yaml \
72
+ --pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_359999/teacher_checkpoint.pth \
73
+ --output-dir <PATH/TO/OUTPUT/DIR>/eval/training_359999/linear \
74
+ --train-dataset HPAFoV:split=TRAIN:mode=PROTEIN_LOCALIZATION:root=<PATH/TO/DATASET> \
75
+ --val-dataset HPAFoV:split=VAL:mode=PROTEIN_LOCALIZATION:root=<PATH/TO/DATASET> \
76
+ --val-metric-type mean_per_class_multilabel_f1 \
77
+ --loss-type binary_cross_entropy \
78
+ --bag-of-channels \
79
+ --crop-size 384 \
80
+ --n-last-blocks 4 \
81
+ --batch-size 32 \
82
+ --epoch-length 145 \
83
+ --epochs 30 \
84
+ --avgpool
85
+ ```
86
+
87
+ ### KNN classification on CHAMMI
88
+
89
+ Go to the docs directory, modify some paths in launcher_knn_eval_on_chammi.sh and run
90
+
91
+ ```shell
92
+ ./launcher_knn_eval_on_chammi.sh WTC TASK_ONE ;
93
+ ./launcher_knn_eval_on_chammi.sh WTC TASK_TWO ;
94
+ ./launcher_knn_eval_on_chammi.sh HPA TASK_ONE ;
95
+ ./launcher_knn_eval_on_chammi.sh HPA TASK_TWO ;
96
+ ./launcher_knn_eval_on_chammi.sh HPA TASK_THREE ;
97
+ ./launcher_knn_eval_on_chammi.sh CP TASK_ONE ;
98
+ ./launcher_knn_eval_on_chammi.sh CP TASK_TWO ;
99
+ ./launcher_knn_eval_on_chammi.sh CP TASK_THREE ;
100
+ ./launcher_knn_eval_on_chammi.sh CP TASK_FOUR ;
101
+ ```
102
+
103
+ ### Linear classification on CHAMMI
104
+
105
+ Go to the docs directory, modify some paths in launcher_CHAMMI_eval.sh and run
106
+
107
+ ```shell
108
+ ./launcher_CHAMMI_eval.sh WTC TASK_ONE ;
109
+ ./launcher_CHAMMI_eval.sh WTC TASK_TWO ;
110
+ ./launcher_CHAMMI_eval.sh HPA TASK_ONE ;
111
+ ./launcher_CHAMMI_eval.sh HPA TASK_TWO ;
112
+ ./launcher_CHAMMI_eval.sh HPA TASK_THREE ;
113
+ ./launcher_CHAMMI_eval.sh CP TASK_ONE ;
114
+ ./launcher_CHAMMI_eval.sh CP TASK_TWO ;
115
+ ./launcher_CHAMMI_eval.sh CP TASK_THREE ;
116
+ ./launcher_CHAMMI_eval.sh CP TASK_FOUR ;
117
+ ```
118
+
119
+ | | WTC - Task 1 | WTC - Task 2 | HPA - Task 1 | HPA - Task 2 | HPA - Task 3 | CP - Task 1 | CP - Task 2 | CP - Task 3 | CP - Task 4 |
120
+ | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
121
+ | knn reproduced | 80.3 | 79.3 | 91.6 | 61.4 | 29.0 | 89.8 | 57.6 | 23.4 | 18.4 |
122
+ | knn paper | 79.4 | 79.0 | 86.6 | 59.3 | 29.6 | 92.6 | 57.6 | 22.1 | 18.5 |
123
+ | Linear reproduced | 89.9 | 87.9 | 92.7 | 87.2 | 66.2 | 89.9 | 59.8 | 26.6 | 32.5|
124
+ | Linear paper | 90.5 | 89.2 | 88.3 | 84.7 | 65.0 | 90.5 | 60.5 | 25.8 | 32.7|
125
+
126
+
127
+ ## License
128
+
129
+ Cell-DINO code is released under the CC by NC licence. See [LICENSE_CELL_DINO_CODE](LICENSE_CELL_DINO_CODE) for additional details.
130
+ Model weights will be released under the FAIR Non-Commercial Research License. See [LICENSE_CELL_DINO_MODELS](LICENSE_CELL_DINO_MODELS) for additional details.
131
+
132
+ ## Contributing
133
+
134
+ See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
135
+
136
+ ## Citing ChannelAdaptiveDINO and DINOv2
137
+
138
+ If you find this repository useful, please consider giving a star :star: and citation :t-rex::
139
+
140
+ ```
141
+ @misc{Delorenci2025scaling,
142
+ title={Scaling Channel-Adaptive Self-Supervised Learning},
143
+ author={V. De Lorenci, Alice and Yi, Seungeun and Moutakanni, Theo and Bojanowski, Piotr and Couprie, Camille and Caicedo, Juan C. and Pernice, Wolfgang M.},
144
+ journal={TMLR},
145
+ year={2025}
146
+ }
147
+ ```
148
+
149
+ ```
150
+ @misc{oquab2023dinov2,
151
+ title={DINOv2: Learning Robust Visual Features without Supervision},
152
+ author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
153
+ journal={arXiv:2304.07193},
154
+ year={2023}
155
+ }
156
+ ```
torch_hub/facebookresearch_dinov2_main/notebooks/cell_dino/inference.ipynb ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "6c3f1fe9-af40-4a57-aff0-99313d722f34",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "# Copyright (c) Meta Platforms, Inc. and affiliates.\n",
11
+ "# This source code is licensed under the CC-by-NC licence,\n",
12
+ "# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree."
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "bfe8d7c4-995c-44b4-a8d1-97be2badce8c",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "SAMPLE_IMAGES_DIR = \"sample_images/\" # path to directory with cell images.\n",
23
+ "REPO_DIR=\"\" # path to the dinov2 repo.\n",
24
+ "# Also define the urls of the pretrained models CPurl, SCurl, FOVurl used in the next cell. Instructions to get the models urls are in the README.md."
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "id": "70b2a71b-8fb2-4ec3-8cd1-21d6f770f704",
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "import torch\n",
35
+ "cell_dino_vits8 = torch.hub.load(REPO_DIR, 'cell_dino_cp_vits8', source='local', pretrained_url=CPurl)\n",
36
+ "cell_dino_vitl16_sc = torch.hub.load(REPO_DIR, 'cell_dino_hpa_vitl16', source='local', pretrained_url=SCurl)\n",
37
+ "cell_dino_vitl16_fov = torch.hub.load(REPO_DIR, 'cell_dino_hpa_vitl16', source='local', pretrained_url=FOVurl)\n",
38
+ "#channel_adaptive_dino_vitl16 = torch.hub.load(REPO_DIR, 'channel_adaptive_dino_vitl16', source='local', pretrained_url=CAurl)\n",
39
+ "# cell_dino_vitl14 = torch.hub.load(REPO_DIR, 'cell_dino_hpa_vitl14', source='local', pretrained_url=HRurl)"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "id": "a6899c21-c4c2-43d4-8cdc-f6cce2c3686b",
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "import torch\n",
50
+ "import torchvision\n",
51
+ "from dinov2.hub.cell_dino.backbones import cell_dino_hpa_vitl16, cell_dino_cp_vits8\n",
52
+ "from functools import partial\n",
53
+ "from dinov2.eval.utils import ModelWithIntermediateLayers\n",
54
+ "\n",
55
+ "DEVICE = \"cuda:0\"\n",
56
+ "\n",
57
+ "class self_normalize(object):\n",
58
+ " def __call__(self, x):\n",
59
+ " x = x / 255\n",
60
+ " m = x.mean((-2, -1), keepdim=True)\n",
61
+ " s = x.std((-2, -1), unbiased=False, keepdim=True)\n",
62
+ " x -= m\n",
63
+ " x /= s + 1e-7\n",
64
+ " return x\n",
65
+ "\n",
66
+ "normalize = self_normalize()"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "id": "779ecbb7-3247-4b9e-9248-25c8cbd74d24",
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "# ---------------------- Example inference on HPA-FoV dataset --------------------------\n",
77
+ "\n",
78
+ "# 1- Read one human protein atlas HPA-FoV image (4 channels)\n",
79
+ "img = torchvision.io.read_image(SAMPLE_IMAGES_DIR + \"HPA_FoV_00070df0-bbc3-11e8-b2bc-ac1f6b6435d0.png\")\n",
80
+ "\n",
81
+ "# 2- Normalise image as it was done for training\n",
82
+ "img_hpa_fov = img.unsqueeze(0).to(device=DEVICE)\n",
83
+ "img_hpa_fov = normalize(img_hpa_fov)\n",
84
+ "\n",
85
+ "# 3- Load model\n",
86
+ "cell_dino_model = cell_dino_vitl16_fov\n",
87
+ "cell_dino_model.to(device=DEVICE)\n",
88
+ "cell_dino_model.eval()\n",
89
+ "\n",
90
+ "# 4- Inference\n",
91
+ "features = cell_dino_model(img_hpa_fov)\n",
92
+ "print(features)\n",
93
+ "\n",
94
+ "# 5- [Optional] feature extractor as used for linear evaluation\n",
95
+ "autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float)\n",
96
+ "model_with_interm_layers = ModelWithIntermediateLayers(cell_dino_model, 4, autocast_ctx)\n",
97
+ "features_with_interm_layers = model_with_interm_layers(img_hpa_fov)"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "id": "85e67e67-5824-4135-8ca2-f8396cf95cb2",
104
+ "metadata": {},
105
+ "outputs": [],
106
+ "source": [
107
+ "# ---------------------- Example inference on cell painting data --------------------------\n",
108
+ "\n",
109
+ "# 1- Read one cell painting image (5 channels)\n",
110
+ "img = torchvision.io.read_image(SAMPLE_IMAGES_DIR + \"CP_BBBC036_24277_a06_1_976@140x149.png\")\n",
111
+ "img5_channels = torch.zeros([1, 5, 160, 160])\n",
112
+ "for c in range(5):\n",
113
+ " img5_channels[0, c] = img[0, :, 160 * c : 160 * (c + 1)]\n",
114
+ "img5_channels = img5_channels.to(device=DEVICE)\n",
115
+ "\n",
116
+ "# 2- Normalise image as it was done for training\n",
117
+ "img5_channels = normalize(img5_channels)\n",
118
+ "\n",
119
+ "# 3- Load model\n",
120
+ "cell_dino_model = cell_dino_vits8\n",
121
+ "cell_dino_model.to(device=DEVICE)\n",
122
+ "cell_dino_model.eval()\n",
123
+ "\n",
124
+ "# 4- Inference\n",
125
+ "features = cell_dino_model(img5_channels)\n",
126
+ "print(features[0,0:10])"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": null,
132
+ "id": "ada26fcc-fd27-4dbf-983c-bbe06fe04b6f",
133
+ "metadata": {},
134
+ "outputs": [],
135
+ "source": [
136
+ "# ---------------------- Example inference on HPA single cell dataset --------------------------\n",
137
+ "\n",
138
+ "# Read one human protein atlas HPA single cell image (4 channels)\n",
139
+ "img = torchvision.io.read_image(SAMPLE_IMAGES_DIR + \"HPA_single_cell_00285ce4-bba0-11e8-b2b9-ac1f6b6435d0_15.png\")\n",
140
+ "\n",
141
+ "# 2- Normalise image as it was done for training\n",
142
+ "img_hpa = img.unsqueeze(0).to(device=DEVICE)\n",
143
+ "img_hpa = normalize(img_hpa)\n",
144
+ "\n",
145
+ "# 3- Load model\n",
146
+ "cell_dino_model = cell_dino_vitl16_sc\n",
147
+ "cell_dino_model.to(device=DEVICE)\n",
148
+ "cell_dino_model.eval()\n",
149
+ "\n",
150
+ "# 4- Inference\n",
151
+ "features = cell_dino_model(img_hpa)\n",
152
+ "print(features)\n",
153
+ "\n",
154
+ "torch.save(features.cpu(), \"sample_features_hpa.pt\")"
155
+ ]
156
+ }
157
+ ],
158
+ "metadata": {
159
+ "kernelspec": {
160
+ "display_name": "Python (mypy310env)",
161
+ "language": "python",
162
+ "name": "mypy310env"
163
+ },
164
+ "language_info": {
165
+ "codemirror_mode": {
166
+ "name": "ipython",
167
+ "version": 3
168
+ },
169
+ "file_extension": ".py",
170
+ "mimetype": "text/x-python",
171
+ "name": "python",
172
+ "nbconvert_exporter": "python",
173
+ "pygments_lexer": "ipython3",
174
+ "version": "3.10.19"
175
+ }
176
+ },
177
+ "nbformat": 4,
178
+ "nbformat_minor": 5
179
+ }
torch_hub/facebookresearch_dinov2_main/notebooks/depth_estimation.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
torch_hub/facebookresearch_dinov2_main/notebooks/semantic_segmentation.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
torch_hub/facebookresearch_dinov2_main/scripts/cell_dino/launcher_knn_eval_on_chammi.sh ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Runs the Cell-DINO k-NN evaluation on one CHAMMI dataset/task pair.
#
# 1 : modify CHANNEL_AGNOSTIC_CELL_MODEL, CHAMMI_DATA_PATH and OUTPUT_DIR below
# 2 : call this script with the two arguments specified below

# Arguments:
#     $1 : dataset, e.g CP
#     $2 : task number, e.g TASK_TWO

if [ "$#" -ne 2 ]; then
    echo "Usage: $0 <DATASET> <TASK>   (e.g. $0 CP TASK_TWO)" >&2
    exit 1
fi

CHAMMI_DATA_PATH=""
CHANNEL_AGNOSTIC_CELL_MODEL="path_to_model/model.pth"
OUTPUT_DIR="YOUR_OUTPUT_PATH_$1_$2"

# Extra arguments for the leave-one-out tasks. An array keeps each argument
# intact even when CHAMMI_DATA_PATH contains spaces.
# NOTE(review): the CP metadata path has no CHAMMI/ component while the HPA one
# does, and TASK_FOUR is not restricted to the CP dataset -- confirm both.
OTHER_ARGS=()
if [ "$2" == "TASK_FOUR" ]; then
    OTHER_ARGS=(--leave-one-out-dataset "$CHAMMI_DATA_PATH/CP/enriched_meta.csv")
elif [ "$1" == "HPA" ] && [ "$2" == "TASK_THREE" ]; then
    OTHER_ARGS=(--leave-one-out-dataset "$CHAMMI_DATA_PATH/CHAMMI/HPA/enriched_meta.csv")
fi
echo "${OTHER_ARGS[@]}"

# NOTE(review): PYTHONPATH starts with ".." while every other path uses "../..";
# confirm which directory this script is meant to be launched from.
PYTHONPATH=..:../../dinov2/data python ../../dinov2/run/eval/cell_dino/knn.py \
    --config-file ../../dinov2/configs/eval/cell_dino/vitl16_channel_adaptive_pretrain.yaml \
    --pretrained-weights "$CHANNEL_AGNOSTIC_CELL_MODEL" \
    --output-dir "$OUTPUT_DIR" \
    --train-dataset "CHAMMI_$1:split=TRAIN:root=$CHAMMI_DATA_PATH" \
    --val-dataset "CHAMMI_$1:split=$2:root=$CHAMMI_DATA_PATH" \
    --metric-type mean_per_class_multiclass_f1 \
    --crop-size 224 \
    --batch-size 32 \
    --resize-size 256 \
    --bag-of-channels \
    "${OTHER_ARGS[@]}"
torch_hub/trusted_list ADDED
File without changes