AndrewKof committed on
Commit 7179b2d · 0 Parent(s):

Initial commit for NEMO tools FastAPI Space

This view is limited to 50 files because it contains too many changes. See the raw diff.
Files changed (50)
  1. .DS_Store +0 -0
  2. README.md +0 -0
  3. app/.DS_Store +0 -0
  4. app/__pycache__/main.cpython-310.pyc +0 -0
  5. app/__pycache__/model.cpython-310.pyc +0 -0
  6. app/dinov2/.DS_Store +0 -0
  7. app/dinov2/__init__.py +6 -0
  8. app/dinov2/__pycache__/__init__.cpython-310.pyc +0 -0
  9. app/dinov2/configs/.DS_Store +0 -0
  10. app/dinov2/configs/__init__.py +22 -0
  11. app/dinov2/configs/eval/vitb14_pretrain.yaml +6 -0
  12. app/dinov2/configs/eval/vitb14_reg4_pretrain.yaml +9 -0
  13. app/dinov2/configs/eval/vitg14_pretrain.yaml +7 -0
  14. app/dinov2/configs/eval/vitg14_reg4_pretrain.yaml +10 -0
  15. app/dinov2/configs/eval/vitl14_pretrain.yaml +6 -0
  16. app/dinov2/configs/eval/vitl14_reg4_pretrain.yaml +9 -0
  17. app/dinov2/configs/eval/vits14_pretrain.yaml +6 -0
  18. app/dinov2/configs/eval/vits14_reg4_pretrain.yaml +9 -0
  19. app/dinov2/configs/ssl_default_config.yaml +118 -0
  20. app/dinov2/configs/train/vitg14.yaml +26 -0
  21. app/dinov2/configs/train/vitl14.yaml +26 -0
  22. app/dinov2/configs/train/vitl16_short.yaml +6 -0
  23. app/dinov2/data/__init__.py +10 -0
  24. app/dinov2/data/adapters.py +28 -0
  25. app/dinov2/data/augmentations.py +118 -0
  26. app/dinov2/data/collate.py +49 -0
  27. app/dinov2/data/datasets/__init__.py +7 -0
  28. app/dinov2/data/datasets/decoders.py +31 -0
  29. app/dinov2/data/datasets/extended.py +38 -0
  30. app/dinov2/data/datasets/image_net.py +290 -0
  31. app/dinov2/data/datasets/image_net_22k.py +302 -0
  32. app/dinov2/data/loaders.py +222 -0
  33. app/dinov2/data/masking.py +86 -0
  34. app/dinov2/data/samplers.py +229 -0
  35. app/dinov2/data/transforms.py +91 -0
  36. app/dinov2/distributed/__init__.py +270 -0
  37. app/dinov2/eval/.DS_Store +0 -0
  38. app/dinov2/eval/__init__.py +4 -0
  39. app/dinov2/eval/depth/__init__.py +4 -0
  40. app/dinov2/eval/depth/models/__init__.py +10 -0
  41. app/dinov2/eval/depth/models/backbones/__init__.py +6 -0
  42. app/dinov2/eval/depth/models/backbones/vision_transformer.py +16 -0
  43. app/dinov2/eval/depth/models/builder.py +49 -0
  44. app/dinov2/eval/depth/models/decode_heads/__init__.py +7 -0
  45. app/dinov2/eval/depth/models/decode_heads/decode_head.py +225 -0
  46. app/dinov2/eval/depth/models/decode_heads/dpt_head.py +270 -0
  47. app/dinov2/eval/depth/models/decode_heads/linear_head.py +89 -0
  48. app/dinov2/eval/depth/models/depther/__init__.py +7 -0
  49. app/dinov2/eval/depth/models/depther/base.py +194 -0
  50. app/dinov2/eval/depth/models/depther/encoder_decoder.py +236 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
README.md ADDED
File without changes
app/.DS_Store ADDED
Binary file (6.15 kB).
 
app/__pycache__/main.cpython-310.pyc ADDED
Binary file (1.27 kB).
 
app/__pycache__/model.cpython-310.pyc ADDED
Binary file (3.16 kB).
 
app/dinov2/.DS_Store ADDED
Binary file (8.2 kB).
 
app/dinov2/__init__.py ADDED
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+__version__ = "0.0.1"
app/dinov2/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (170 Bytes).
 
app/dinov2/configs/.DS_Store ADDED
Binary file (6.15 kB).
 
app/dinov2/configs/__init__.py ADDED
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import pathlib
+
+from omegaconf import OmegaConf
+
+
+def load_config(config_name: str):
+    config_filename = config_name + ".yaml"
+    return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename)
+
+
+dinov2_default_config = load_config("ssl_default_config")
+
+
+def load_and_merge_config(config_name: str):
+    default_config = OmegaConf.create(dinov2_default_config)
+    loaded_config = load_config(config_name)
+    return OmegaConf.merge(default_config, loaded_config)
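
Usage sketch (not part of the commit): load_config resolves names relative to this configs directory, so the eval presets added below can be merged over the SSL defaults. Assumes the package is importable as dinov2, which matches the absolute imports used elsewhere in this tree.

    from dinov2.configs import load_and_merge_config

    # Merge an eval preset over ssl_default_config.
    cfg = load_and_merge_config("eval/vitb14_pretrain")
    print(cfg.student.arch, cfg.crops.global_crops_size)  # vit_base 518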
app/dinov2/configs/eval/vitb14_pretrain.yaml ADDED
@@ -0,0 +1,6 @@
+student:
+  arch: vit_base
+  patch_size: 14
+crops:
+  global_crops_size: 518 # this is to set up the position embeddings properly
+  local_crops_size: 98
app/dinov2/configs/eval/vitb14_reg4_pretrain.yaml ADDED
@@ -0,0 +1,9 @@
+student:
+  arch: vit_base
+  patch_size: 14
+  num_register_tokens: 4
+  interpolate_antialias: true
+  interpolate_offset: 0.0
+crops:
+  global_crops_size: 518 # this is to set up the position embeddings properly
+  local_crops_size: 98
app/dinov2/configs/eval/vitg14_pretrain.yaml ADDED
@@ -0,0 +1,7 @@
+student:
+  arch: vit_giant2
+  patch_size: 14
+  ffn_layer: swiglufused
+crops:
+  global_crops_size: 518 # this is to set up the position embeddings properly
+  local_crops_size: 98
app/dinov2/configs/eval/vitg14_reg4_pretrain.yaml ADDED
@@ -0,0 +1,10 @@
+student:
+  arch: vit_giant2
+  patch_size: 14
+  ffn_layer: swiglufused
+  num_register_tokens: 4
+  interpolate_antialias: true
+  interpolate_offset: 0.0
+crops:
+  global_crops_size: 518 # this is to set up the position embeddings properly
+  local_crops_size: 98
app/dinov2/configs/eval/vitl14_pretrain.yaml ADDED
@@ -0,0 +1,6 @@
+student:
+  arch: vit_large
+  patch_size: 14
+crops:
+  global_crops_size: 518 # this is to set up the position embeddings properly
+  local_crops_size: 98
app/dinov2/configs/eval/vitl14_reg4_pretrain.yaml ADDED
@@ -0,0 +1,9 @@
+student:
+  arch: vit_large
+  patch_size: 14
+  num_register_tokens: 4
+  interpolate_antialias: true
+  interpolate_offset: 0.0
+crops:
+  global_crops_size: 518 # this is to set up the position embeddings properly
+  local_crops_size: 98
app/dinov2/configs/eval/vits14_pretrain.yaml ADDED
@@ -0,0 +1,6 @@
+student:
+  arch: vit_small
+  patch_size: 14
+crops:
+  global_crops_size: 518 # this is to set up the position embeddings properly
+  local_crops_size: 98
app/dinov2/configs/eval/vits14_reg4_pretrain.yaml ADDED
@@ -0,0 +1,9 @@
+student:
+  arch: vit_small
+  patch_size: 14
+  num_register_tokens: 4
+  interpolate_antialias: true
+  interpolate_offset: 0.0
+crops:
+  global_crops_size: 518 # this is to set up the position embeddings properly
+  local_crops_size: 98
app/dinov2/configs/ssl_default_config.yaml ADDED
@@ -0,0 +1,118 @@
+MODEL:
+  WEIGHTS: ''
+compute_precision:
+  grad_scaler: true
+  teacher:
+    backbone:
+      sharding_strategy: SHARD_GRAD_OP
+      mixed_precision:
+        param_dtype: fp16
+        reduce_dtype: fp16
+        buffer_dtype: fp32
+    dino_head:
+      sharding_strategy: SHARD_GRAD_OP
+      mixed_precision:
+        param_dtype: fp16
+        reduce_dtype: fp16
+        buffer_dtype: fp32
+    ibot_head:
+      sharding_strategy: SHARD_GRAD_OP
+      mixed_precision:
+        param_dtype: fp16
+        reduce_dtype: fp16
+        buffer_dtype: fp32
+  student:
+    backbone:
+      sharding_strategy: SHARD_GRAD_OP
+      mixed_precision:
+        param_dtype: fp16
+        reduce_dtype: fp16
+        buffer_dtype: fp32
+    dino_head:
+      sharding_strategy: SHARD_GRAD_OP
+      mixed_precision:
+        param_dtype: fp16
+        reduce_dtype: fp32
+        buffer_dtype: fp32
+    ibot_head:
+      sharding_strategy: SHARD_GRAD_OP
+      mixed_precision:
+        param_dtype: fp16
+        reduce_dtype: fp32
+        buffer_dtype: fp32
+dino:
+  loss_weight: 1.0
+  head_n_prototypes: 65536
+  head_bottleneck_dim: 256
+  head_nlayers: 3
+  head_hidden_dim: 2048
+  koleo_loss_weight: 0.1
+ibot:
+  loss_weight: 1.0
+  mask_sample_probability: 0.5
+  mask_ratio_min_max:
+  - 0.1
+  - 0.5
+  separate_head: false
+  head_n_prototypes: 65536
+  head_bottleneck_dim: 256
+  head_nlayers: 3
+  head_hidden_dim: 2048
+train:
+  batch_size_per_gpu: 64
+  dataset_path: ImageNet:split=TRAIN
+  output_dir: .
+  saveckp_freq: 20
+  seed: 0
+  num_workers: 10
+  OFFICIAL_EPOCH_LENGTH: 1250
+  cache_dataset: true
+  centering: "centering" # or "sinkhorn_knopp"
+student:
+  arch: vit_large
+  patch_size: 16
+  drop_path_rate: 0.3
+  layerscale: 1.0e-05
+  drop_path_uniform: true
+  pretrained_weights: ''
+  ffn_layer: "mlp"
+  block_chunks: 0
+  qkv_bias: true
+  proj_bias: true
+  ffn_bias: true
+  num_register_tokens: 0
+  interpolate_antialias: false
+  interpolate_offset: 0.1
+teacher:
+  momentum_teacher: 0.992
+  final_momentum_teacher: 1
+  warmup_teacher_temp: 0.04
+  teacher_temp: 0.07
+  warmup_teacher_temp_epochs: 30
+optim:
+  epochs: 100
+  weight_decay: 0.04
+  weight_decay_end: 0.4
+  base_lr: 0.004 # learning rate for a batch size of 1024
+  lr: 0. # will be set after applying scaling rule
+  warmup_epochs: 10
+  min_lr: 1.0e-06
+  clip_grad: 3.0
+  freeze_last_layer_epochs: 1
+  scaling_rule: sqrt_wrt_1024
+  patch_embed_lr_mult: 0.2
+  layerwise_decay: 0.9
+  adamw_beta1: 0.9
+  adamw_beta2: 0.999
+crops:
+  global_crops_scale:
+  - 0.32
+  - 1.0
+  local_crops_number: 8
+  local_crops_scale:
+  - 0.05
+  - 0.32
+  global_crops_size: 224
+  local_crops_size: 96
evaluation:
  eval_period_iterations: 12500
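
Note (not part of the commit): base_lr is documented as the learning rate for a batch size of 1024, and lr is derived from it via scaling_rule before training. The training entry point that applies the rule is not in this view; a hedged sketch of the assumed sqrt_wrt_1024 form:

    import math

    # Assumed: lr scales with the square root of the global batch size,
    # relative to a reference batch size of 1024.
    base_lr, batch_size_per_gpu, num_gpus = 0.004, 64, 16
    lr = base_lr * math.sqrt(batch_size_per_gpu * num_gpus / 1024.0)  # equals base_lr at 1024 images/batch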
app/dinov2/configs/train/vitg14.yaml ADDED
@@ -0,0 +1,26 @@
+dino:
+  head_n_prototypes: 131072
+  head_bottleneck_dim: 384
+ibot:
+  separate_head: true
+  head_n_prototypes: 131072
+train:
+  batch_size_per_gpu: 12
+  dataset_path: ImageNet22k
+  centering: sinkhorn_knopp
+student:
+  arch: vit_giant2
+  patch_size: 14
+  drop_path_rate: 0.4
+  ffn_layer: swiglufused
+  block_chunks: 4
+teacher:
+  momentum_teacher: 0.994
+optim:
+  epochs: 500
+  weight_decay_end: 0.2
+  base_lr: 2.0e-04 # learning rate for a batch size of 1024
+  warmup_epochs: 80
+  layerwise_decay: 1.0
+crops:
+  local_crops_size: 98
app/dinov2/configs/train/vitl14.yaml ADDED
@@ -0,0 +1,26 @@
+dino:
+  head_n_prototypes: 131072
+  head_bottleneck_dim: 384
+ibot:
+  separate_head: true
+  head_n_prototypes: 131072
+train:
+  batch_size_per_gpu: 32
+  dataset_path: ImageNet22k
+  centering: sinkhorn_knopp
+student:
+  arch: vit_large
+  patch_size: 14
+  drop_path_rate: 0.4
+  ffn_layer: swiglufused
+  block_chunks: 4
+teacher:
+  momentum_teacher: 0.994
+optim:
+  epochs: 500
+  weight_decay_end: 0.2
+  base_lr: 2.0e-04 # learning rate for a batch size of 1024
+  warmup_epochs: 80
+  layerwise_decay: 1.0
+crops:
+  local_crops_size: 98
app/dinov2/configs/train/vitl16_short.yaml ADDED
@@ -0,0 +1,6 @@
+# this corresponds to the default config
+train:
+  dataset_path: ImageNet:split=TRAIN
+  batch_size_per_gpu: 64
+student:
+  block_chunks: 4
app/dinov2/data/__init__.py ADDED
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .adapters import DatasetWithEnumeratedTargets
+from .loaders import make_data_loader, make_dataset, SamplerType
+from .collate import collate_data_and_cast
+from .masking import MaskingGenerator
+from .augmentations import DataAugmentationDINO
app/dinov2/data/adapters.py ADDED
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from typing import Any, Tuple
+
+from torch.utils.data import Dataset
+
+
+class DatasetWithEnumeratedTargets(Dataset):
+    def __init__(self, dataset):
+        self._dataset = dataset
+
+    def get_image_data(self, index: int) -> bytes:
+        return self._dataset.get_image_data(index)
+
+    def get_target(self, index: int) -> Tuple[Any, int]:
+        target = self._dataset.get_target(index)
+        return (index, target)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
+        image, target = self._dataset[index]
+        target = index if target is None else target
+        return image, (index, target)
+
+    def __len__(self) -> int:
+        return len(self._dataset)
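
Usage sketch (not part of the commit; base_dataset is a hypothetical dataset whose __getitem__ returns an (image, target) pair):

    from dinov2.data import DatasetWithEnumeratedTargets

    wrapped = DatasetWithEnumeratedTargets(base_dataset)
    image, (index, target) = wrapped[0]  # each target is paired with its sample index
    # Samples whose target is None fall back to the index itself as the target.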
app/dinov2/data/augmentations.py ADDED
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+
+from torchvision import transforms
+
+from .transforms import (
+    GaussianBlur,
+    make_normalize_transform,
+)
+
+
+logger = logging.getLogger("dinov2")
+
+
+class DataAugmentationDINO(object):
+    def __init__(
+        self,
+        global_crops_scale,
+        local_crops_scale,
+        local_crops_number,
+        global_crops_size=224,
+        local_crops_size=96,
+    ):
+        self.global_crops_scale = global_crops_scale
+        self.local_crops_scale = local_crops_scale
+        self.local_crops_number = local_crops_number
+        self.global_crops_size = global_crops_size
+        self.local_crops_size = local_crops_size
+
+        logger.info("###################################")
+        logger.info("Using data augmentation parameters:")
+        logger.info(f"global_crops_scale: {global_crops_scale}")
+        logger.info(f"local_crops_scale: {local_crops_scale}")
+        logger.info(f"local_crops_number: {local_crops_number}")
+        logger.info(f"global_crops_size: {global_crops_size}")
+        logger.info(f"local_crops_size: {local_crops_size}")
+        logger.info("###################################")
+
+        # random resized crop and flip
+        self.geometric_augmentation_global = transforms.Compose(
+            [
+                transforms.RandomResizedCrop(
+                    global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
+                ),
+                transforms.RandomHorizontalFlip(p=0.5),
+            ]
+        )
+
+        self.geometric_augmentation_local = transforms.Compose(
+            [
+                transforms.RandomResizedCrop(
+                    local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
+                ),
+                transforms.RandomHorizontalFlip(p=0.5),
+            ]
+        )
+
+        # color distortions / blurring
+        color_jittering = transforms.Compose(
+            [
+                transforms.RandomApply(
+                    [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
+                    p=0.8,
+                ),
+                transforms.RandomGrayscale(p=0.2),
+            ]
+        )
+
+        global_transfo1_extra = GaussianBlur(p=1.0)
+
+        global_transfo2_extra = transforms.Compose(
+            [
+                GaussianBlur(p=0.1),
+                transforms.RandomSolarize(threshold=128, p=0.2),
+            ]
+        )
+
+        local_transfo_extra = GaussianBlur(p=0.5)
+
+        # normalization
+        self.normalize = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                make_normalize_transform(),
+            ]
+        )
+
+        self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
+        self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
+        self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
+
+    def __call__(self, image):
+        output = {}
+
+        # global crops:
+        im1_base = self.geometric_augmentation_global(image)
+        global_crop_1 = self.global_transfo1(im1_base)
+
+        im2_base = self.geometric_augmentation_global(image)
+        global_crop_2 = self.global_transfo2(im2_base)
+
+        output["global_crops"] = [global_crop_1, global_crop_2]
+
+        # global crops for teacher:
+        output["global_crops_teacher"] = [global_crop_1, global_crop_2]
+
+        # local crops:
+        local_crops = [
+            self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
+        ]
+        output["local_crops"] = local_crops
+        output["offsets"] = ()
+
+        return output
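
Usage sketch (not part of the commit; scale values mirror ssl_default_config.yaml above, example.jpg is a placeholder path):

    from PIL import Image

    from dinov2.data import DataAugmentationDINO

    transform = DataAugmentationDINO(
        global_crops_scale=(0.32, 1.0),
        local_crops_scale=(0.05, 0.32),
        local_crops_number=8,
    )
    crops = transform(Image.open("example.jpg"))
    # crops["global_crops"] holds 2 tensors (224x224 by default),
    # crops["local_crops"] holds 8 tensors (96x96 by default).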
app/dinov2/data/collate.py ADDED
@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import random
+
+
+def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
+    # dtype = torch.half  # TODO: Remove
+
+    n_global_crops = len(samples_list[0][0]["global_crops"])
+    n_local_crops = len(samples_list[0][0]["local_crops"])
+
+    collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list])
+
+    collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list])
+
+    B = len(collated_global_crops)
+    N = n_tokens
+    n_samples_masked = int(B * mask_probability)
+    probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
+    upperbound = 0
+    masks_list = []
+    for i in range(0, n_samples_masked):
+        prob_min = probs[i]
+        prob_max = probs[i + 1]
+        masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
+        upperbound += int(N * prob_max)
+    for i in range(n_samples_masked, B):
+        masks_list.append(torch.BoolTensor(mask_generator(0)))
+
+    random.shuffle(masks_list)
+
+    collated_masks = torch.stack(masks_list).flatten(1)
+    mask_indices_list = collated_masks.flatten().nonzero().flatten()
+
+    masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
+
+    return {
+        "collated_global_crops": collated_global_crops.to(dtype),
+        "collated_local_crops": collated_local_crops.to(dtype),
+        "collated_masks": collated_masks,
+        "mask_indices_list": mask_indices_list,
+        "masks_weight": masks_weight,
+        "upperbound": upperbound,
+        "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
+    }
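
Since collate_data_and_cast takes arguments beyond the samples list, it is presumably bound with functools.partial before being handed to a DataLoader; the actual wiring lives in the training entry point, which is not part of this view. A hedged sketch (values mirror ssl_default_config.yaml; MaskingGenerator is added in app/dinov2/data/masking.py below):

    from functools import partial

    import torch

    from dinov2.data import MaskingGenerator, collate_data_and_cast

    img_size, patch_size = 224, 16
    n_tokens = (img_size // patch_size) ** 2
    mask_generator = MaskingGenerator(
        input_size=img_size // patch_size,
        max_num_patches=int(0.5 * n_tokens),
    )
    collate_fn = partial(
        collate_data_and_cast,
        mask_ratio_tuple=(0.1, 0.5),
        mask_probability=0.5,
        dtype=torch.half,
        n_tokens=n_tokens,
        mask_generator=mask_generator,
    )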
app/dinov2/data/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .image_net import ImageNet
+from .image_net_22k import ImageNet22k
app/dinov2/data/datasets/decoders.py ADDED
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from io import BytesIO
+from typing import Any
+
+from PIL import Image
+
+
+class Decoder:
+    def decode(self) -> Any:
+        raise NotImplementedError
+
+
+class ImageDataDecoder(Decoder):
+    def __init__(self, image_data: bytes) -> None:
+        self._image_data = image_data
+
+    def decode(self) -> Image:
+        f = BytesIO(self._image_data)
+        return Image.open(f).convert(mode="RGB")
+
+
+class TargetDecoder(Decoder):
+    def __init__(self, target: Any):
+        self._target = target
+
+    def decode(self) -> Any:
+        return self._target
app/dinov2/data/datasets/extended.py ADDED
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from typing import Any, Tuple
+
+from torchvision.datasets import VisionDataset
+
+from .decoders import TargetDecoder, ImageDataDecoder
+
+
+class ExtendedVisionDataset(VisionDataset):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)  # type: ignore
+
+    def get_image_data(self, index: int) -> bytes:
+        raise NotImplementedError
+
+    def get_target(self, index: int) -> Any:
+        raise NotImplementedError
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        try:
+            image_data = self.get_image_data(index)
+            image = ImageDataDecoder(image_data).decode()
+        except Exception as e:
+            raise RuntimeError(f"can not read image for sample {index}") from e
+        target = self.get_target(index)
+        target = TargetDecoder(target).decode()
+
+        if self.transforms is not None:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def __len__(self) -> int:
+        raise NotImplementedError
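
Minimal subclass sketch (not part of the commit; FolderOfJpegs is hypothetical) showing the contract: only get_image_data, get_target, and __len__ need overriding, while decoding and transforms are inherited:

    from dinov2.data.datasets.extended import ExtendedVisionDataset

    class FolderOfJpegs(ExtendedVisionDataset):
        def __init__(self, root, paths, **kwargs):
            super().__init__(root, **kwargs)
            self._paths = paths  # assumed: a list of image file paths

        def get_image_data(self, index: int) -> bytes:
            with open(self._paths[index], "rb") as f:
                return f.read()

        def get_target(self, index: int):
            return None  # unlabeled

        def __len__(self) -> int:
            return len(self._paths)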
app/dinov2/data/datasets/image_net.py ADDED
@@ -0,0 +1,290 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import csv
+from enum import Enum
+import logging
+import os
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+
+from .extended import ExtendedVisionDataset
+
+
+logger = logging.getLogger("dinov2")
+_Target = int
+
+
+class _Split(Enum):
+    TRAIN = "train"
+    VAL = "val"
+    TEST = "test"  # NOTE: torchvision does not support the test split
+
+    @property
+    def length(self) -> int:
+        split_lengths = {
+            _Split.TRAIN: 1_281_167,
+            _Split.VAL: 50_000,
+            _Split.TEST: 100_000,
+        }
+        return split_lengths[self]
+
+    def get_dirname(self, class_id: Optional[str] = None) -> str:
+        return self.value if class_id is None else os.path.join(self.value, class_id)
+
+    def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
+        dirname = self.get_dirname(class_id)
+        if self == _Split.TRAIN:
+            basename = f"{class_id}_{actual_index}"
+        else:  # self in (_Split.VAL, _Split.TEST):
+            basename = f"ILSVRC2012_{self.value}_{actual_index:08d}"
+        return os.path.join(dirname, basename + ".JPEG")
+
+    def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
+        assert self != _Split.TEST
+        dirname, filename = os.path.split(image_relpath)
+        class_id = os.path.split(dirname)[-1]
+        basename, _ = os.path.splitext(filename)
+        actual_index = int(basename.split("_")[-1])
+        return class_id, actual_index
+
+
+class ImageNet(ExtendedVisionDataset):
+    Target = Union[_Target]
+    Split = Union[_Split]
+
+    def __init__(
+        self,
+        *,
+        split: "ImageNet.Split",
+        root: str,
+        extra: str,
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms, transform, target_transform)
+        self._extra_root = extra
+        self._split = split
+
+        self._entries = None
+        self._class_ids = None
+        self._class_names = None
+
+    @property
+    def split(self) -> "ImageNet.Split":
+        return self._split
+
+    def _get_extra_full_path(self, extra_path: str) -> str:
+        return os.path.join(self._extra_root, extra_path)
+
+    def _load_extra(self, extra_path: str) -> np.ndarray:
+        extra_full_path = self._get_extra_full_path(extra_path)
+        return np.load(extra_full_path, mmap_mode="r")
+
+    def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
+        extra_full_path = self._get_extra_full_path(extra_path)
+        os.makedirs(self._extra_root, exist_ok=True)
+        np.save(extra_full_path, extra_array)
+
+    @property
+    def _entries_path(self) -> str:
+        return f"entries-{self._split.value.upper()}.npy"
+
+    @property
+    def _class_ids_path(self) -> str:
+        return f"class-ids-{self._split.value.upper()}.npy"
+
+    @property
+    def _class_names_path(self) -> str:
+        return f"class-names-{self._split.value.upper()}.npy"
+
+    def _get_entries(self) -> np.ndarray:
+        if self._entries is None:
+            self._entries = self._load_extra(self._entries_path)
+        assert self._entries is not None
+        return self._entries
+
+    def _get_class_ids(self) -> np.ndarray:
+        if self._split == _Split.TEST:
+            assert False, "Class IDs are not available in TEST split"
+        if self._class_ids is None:
+            self._class_ids = self._load_extra(self._class_ids_path)
+        assert self._class_ids is not None
+        return self._class_ids
+
+    def _get_class_names(self) -> np.ndarray:
+        if self._split == _Split.TEST:
+            assert False, "Class names are not available in TEST split"
+        if self._class_names is None:
+            self._class_names = self._load_extra(self._class_names_path)
+        assert self._class_names is not None
+        return self._class_names
+
+    def find_class_id(self, class_index: int) -> str:
+        class_ids = self._get_class_ids()
+        return str(class_ids[class_index])
+
+    def find_class_name(self, class_index: int) -> str:
+        class_names = self._get_class_names()
+        return str(class_names[class_index])
+
+    def get_image_data(self, index: int) -> bytes:
+        entries = self._get_entries()
+        actual_index = entries[index]["actual_index"]
+
+        class_id = self.get_class_id(index)
+
+        image_relpath = self.split.get_image_relpath(actual_index, class_id)
+        image_full_path = os.path.join(self.root, image_relpath)
+        with open(image_full_path, mode="rb") as f:
+            image_data = f.read()
+        return image_data
+
+    def get_target(self, index: int) -> Optional[Target]:
+        entries = self._get_entries()
+        class_index = entries[index]["class_index"]
+        return None if self.split == _Split.TEST else int(class_index)
+
+    def get_targets(self) -> Optional[np.ndarray]:
+        entries = self._get_entries()
+        return None if self.split == _Split.TEST else entries["class_index"]
+
+    def get_class_id(self, index: int) -> Optional[str]:
+        entries = self._get_entries()
+        class_id = entries[index]["class_id"]
+        return None if self.split == _Split.TEST else str(class_id)
+
+    def get_class_name(self, index: int) -> Optional[str]:
+        entries = self._get_entries()
+        class_name = entries[index]["class_name"]
+        return None if self.split == _Split.TEST else str(class_name)
+
+    def __len__(self) -> int:
+        entries = self._get_entries()
+        assert len(entries) == self.split.length
+        return len(entries)
+
+    def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
+        labels_full_path = os.path.join(self.root, labels_path)
+        labels = []
+
+        try:
+            with open(labels_full_path, "r") as f:
+                reader = csv.reader(f)
+                for row in reader:
+                    class_id, class_name = row
+                    labels.append((class_id, class_name))
+        except OSError as e:
+            raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e
+
+        return labels
+
+    def _dump_entries(self) -> None:
+        split = self.split
+        if split == ImageNet.Split.TEST:
+            dataset = None
+            sample_count = split.length
+            max_class_id_length, max_class_name_length = 0, 0
+        else:
+            labels_path = "labels.txt"
+            logger.info(f'loading labels from "{labels_path}"')
+            labels = self._load_labels(labels_path)
+
+            # NOTE: Using torchvision ImageFolder for consistency
+            from torchvision.datasets import ImageFolder
+
+            dataset_root = os.path.join(self.root, split.get_dirname())
+            dataset = ImageFolder(dataset_root)
+            sample_count = len(dataset)
+            max_class_id_length, max_class_name_length = -1, -1
+            for sample in dataset.samples:
+                _, class_index = sample
+                class_id, class_name = labels[class_index]
+                max_class_id_length = max(len(class_id), max_class_id_length)
+                max_class_name_length = max(len(class_name), max_class_name_length)
+
+        dtype = np.dtype(
+            [
+                ("actual_index", "<u4"),
+                ("class_index", "<u4"),
+                ("class_id", f"U{max_class_id_length}"),
+                ("class_name", f"U{max_class_name_length}"),
+            ]
+        )
+        entries_array = np.empty(sample_count, dtype=dtype)
+
+        if split == ImageNet.Split.TEST:
+            old_percent = -1
+            for index in range(sample_count):
+                percent = 100 * (index + 1) // sample_count
+                if percent > old_percent:
+                    logger.info(f"creating entries: {percent}%")
+                    old_percent = percent
+
+                actual_index = index + 1
+                class_index = np.uint32(-1)
+                class_id, class_name = "", ""
+                entries_array[index] = (actual_index, class_index, class_id, class_name)
+        else:
+            class_names = {class_id: class_name for class_id, class_name in labels}
+
+            assert dataset
+            old_percent = -1
+            for index in range(sample_count):
+                percent = 100 * (index + 1) // sample_count
+                if percent > old_percent:
+                    logger.info(f"creating entries: {percent}%")
+                    old_percent = percent
+
+                image_full_path, class_index = dataset.samples[index]
+                image_relpath = os.path.relpath(image_full_path, self.root)
+                class_id, actual_index = split.parse_image_relpath(image_relpath)
+                class_name = class_names[class_id]
+                entries_array[index] = (actual_index, class_index, class_id, class_name)
+
+        logger.info(f'saving entries to "{self._entries_path}"')
+        self._save_extra(entries_array, self._entries_path)
+
+    def _dump_class_ids_and_names(self) -> None:
+        split = self.split
+        if split == ImageNet.Split.TEST:
+            return
+
+        entries_array = self._load_extra(self._entries_path)
+
+        max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
+        for entry in entries_array:
+            class_index, class_id, class_name = (
+                entry["class_index"],
+                entry["class_id"],
+                entry["class_name"],
+            )
+            max_class_index = max(int(class_index), max_class_index)
+            max_class_id_length = max(len(str(class_id)), max_class_id_length)
+            max_class_name_length = max(len(str(class_name)), max_class_name_length)
+
+        class_count = max_class_index + 1
+        class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}")
+        class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}")
+        for entry in entries_array:
+            class_index, class_id, class_name = (
+                entry["class_index"],
+                entry["class_id"],
+                entry["class_name"],
+            )
+            class_ids_array[class_index] = class_id
+            class_names_array[class_index] = class_name
+
+        logger.info(f'saving class IDs to "{self._class_ids_path}"')
+        self._save_extra(class_ids_array, self._class_ids_path)
+
+        logger.info(f'saving class names to "{self._class_names_path}"')
+        self._save_extra(class_names_array, self._class_names_path)
+
+    def dump_extra(self) -> None:
+        self._dump_entries()
+        self._dump_class_ids_and_names()
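
The entries-*.npy and class-*.npy metadata files that _load_extra reads must be generated once up front. A hedged sketch (not part of the commit; placeholder paths, assuming the standard ImageNet directory layout plus a labels.txt CSV under the root):

    from dinov2.data.datasets import ImageNet

    for split in ImageNet.Split:
        dataset = ImageNet(split=split, root="<IMAGENET_ROOT>", extra="<EXTRA_DIR>")
        dataset.dump_extra()  # writes entries-*.npy, class-ids-*.npy, class-names-*.npy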
app/dinov2/data/datasets/image_net_22k.py ADDED
@@ -0,0 +1,302 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from enum import Enum
+from functools import lru_cache
+from gzip import GzipFile
+from io import BytesIO
+from mmap import ACCESS_READ, mmap
+import os
+from typing import Any, Callable, List, Optional, Set, Tuple
+import warnings
+
+import numpy as np
+
+from .extended import ExtendedVisionDataset
+
+
+_Labels = int
+
+_DEFAULT_MMAP_CACHE_SIZE = 16  # Warning: This can exhaust file descriptors
+
+
+@dataclass
+class _ClassEntry:
+    block_offset: int
+    maybe_filename: Optional[str] = None
+
+
+@dataclass
+class _Entry:
+    class_index: int  # noqa: E701
+    start_offset: int
+    end_offset: int
+    filename: str
+
+
+class _Split(Enum):
+    TRAIN = "train"
+    VAL = "val"
+
+    @property
+    def length(self) -> int:
+        return {
+            _Split.TRAIN: 11_797_647,
+            _Split.VAL: 561_050,
+        }[self]
+
+    def entries_path(self):
+        return f"imagenet21kp_{self.value}.txt"
+
+
+def _get_tarball_path(class_id: str) -> str:
+    return f"{class_id}.tar"
+
+
+def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
+    @lru_cache(maxsize=mmap_cache_size)
+    def _mmap_tarball(class_id: str) -> mmap:
+        tarball_path = _get_tarball_path(class_id)
+        tarball_full_path = os.path.join(tarballs_root, tarball_path)
+        with open(tarball_full_path) as f:
+            return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)
+
+    return _mmap_tarball
+
+
+class ImageNet22k(ExtendedVisionDataset):
+    _GZIPPED_INDICES: Set[int] = {
+        841_545,
+        1_304_131,
+        2_437_921,
+        2_672_079,
+        2_795_676,
+        2_969_786,
+        6_902_965,
+        6_903_550,
+        6_903_628,
+        7_432_557,
+        7_432_589,
+        7_813_809,
+        8_329_633,
+        10_296_990,
+        10_417_652,
+        10_492_265,
+        10_598_078,
+        10_782_398,
+        10_902_612,
+        11_203_736,
+        11_342_890,
+        11_397_596,
+        11_589_762,
+        11_705_103,
+        12_936_875,
+        13_289_782,
+    }
+    Labels = _Labels
+
+    def __init__(
+        self,
+        *,
+        root: str,
+        extra: str,
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
+    ) -> None:
+        super().__init__(root, transforms, transform, target_transform)
+        self._extra_root = extra
+
+        entries_path = self._get_entries_path(root)
+        self._entries = self._load_extra(entries_path)
+
+        class_ids_path = self._get_class_ids_path(root)
+        self._class_ids = self._load_extra(class_ids_path)
+
+        self._gzipped_indices = ImageNet22k._GZIPPED_INDICES
+        self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)
+
+    def _get_entries_path(self, root: Optional[str] = None) -> str:
+        return "entries.npy"
+
+    def _get_class_ids_path(self, root: Optional[str] = None) -> str:
+        return "class-ids.npy"
+
+    def _find_class_ids(self, path: str) -> List[str]:
+        class_ids = []
+
+        with os.scandir(path) as entries:
+            for entry in entries:
+                root, ext = os.path.splitext(entry.name)
+                if ext != ".tar":
+                    continue
+                class_ids.append(root)
+
+        return sorted(class_ids)
+
+    def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]:
+        root = self.get_root(root)
+        entries: List[_Entry] = []
+        class_ids = self._find_class_ids(root)
+
+        for class_index, class_id in enumerate(class_ids):
+            path = os.path.join(root, "blocks", f"{class_id}.log")
+            class_entries = []
+
+            try:
+                with open(path) as f:
+                    for line in f:
+                        line = line.rstrip()
+                        block, filename = line.split(":")
+                        block_offset = int(block[6:])
+                        filename = filename[1:]
+
+                        maybe_filename = None
+                        if filename != "** Block of NULs **":
+                            maybe_filename = filename
+                            _, ext = os.path.splitext(filename)
+                            # assert ext == ".JPEG"
+
+                        class_entry = _ClassEntry(block_offset, maybe_filename)
+                        class_entries.append(class_entry)
+            except OSError as e:
+                raise RuntimeError(f'can not read blocks file "{path}"') from e
+
+            assert class_entries[-1].maybe_filename is None
+
+            for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]):
+                assert class_entry1.block_offset <= class_entry2.block_offset
+                start_offset = 512 * class_entry1.block_offset
+                end_offset = 512 * class_entry2.block_offset
+                assert class_entry1.maybe_filename is not None
+                filename = class_entry1.maybe_filename
+                entry = _Entry(class_index, start_offset, end_offset, filename)
+                # Skip invalid image files (PIL throws UnidentifiedImageError)
+                if filename == "n06470073_47249.JPEG":
+                    continue
+                entries.append(entry)
+
+        return entries, class_ids
+
+    def _load_extra(self, extra_path: str) -> np.ndarray:
+        extra_root = self._extra_root
+        extra_full_path = os.path.join(extra_root, extra_path)
+        return np.load(extra_full_path, mmap_mode="r")
+
+    def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
+        extra_root = self._extra_root
+        extra_full_path = os.path.join(extra_root, extra_path)
+        os.makedirs(extra_root, exist_ok=True)
+        np.save(extra_full_path, extra_array)
+
+    @property
+    def _tarballs_root(self) -> str:
+        return self.root
+
+    def find_class_id(self, class_index: int) -> str:
+        return str(self._class_ids[class_index])
+
+    def get_image_data(self, index: int) -> bytes:
+        entry = self._entries[index]
+        class_id = entry["class_id"]
+        class_mmap = self._mmap_tarball(class_id)
+
+        start_offset, end_offset = entry["start_offset"], entry["end_offset"]
+        try:
+            mapped_data = class_mmap[start_offset:end_offset]
+            data = mapped_data[512:]  # Skip entry header block
+
+            if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B):
+                assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}"
+                with GzipFile(fileobj=BytesIO(data)) as g:
+                    data = g.read()
+        except Exception as e:
+            raise RuntimeError(f'can not retrieve image data for sample {index} from "{class_id}" tarball') from e
+
+        return data
+
+    def get_target(self, index: int) -> Any:
+        return int(self._entries[index]["class_index"])
+
+    def get_targets(self) -> np.ndarray:
+        return self._entries["class_index"]
+
+    def get_class_id(self, index: int) -> str:
+        return str(self._entries[index]["class_id"])
+
+    def get_class_ids(self) -> np.ndarray:
+        return self._entries["class_id"]
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            return super().__getitem__(index)
+
+    def __len__(self) -> int:
+        return len(self._entries)
+
+    def _dump_entries(self, *args, **kwargs) -> None:
+        entries, class_ids = self._load_entries_class_ids(*args, **kwargs)
+
+        max_class_id_length, max_filename_length, max_class_index = -1, -1, -1
+        for entry in entries:
+            class_id = class_ids[entry.class_index]
+            max_class_index = max(entry.class_index, max_class_index)
+            max_class_id_length = max(len(class_id), max_class_id_length)
+            max_filename_length = max(len(entry.filename), max_filename_length)
+
+        dtype = np.dtype(
+            [
+                ("class_index", "<u4"),
+                ("class_id", f"U{max_class_id_length}"),
+                ("start_offset", "<u4"),
+                ("end_offset", "<u4"),
+                ("filename", f"U{max_filename_length}"),
+            ]
+        )
+        sample_count = len(entries)
+        entries_array = np.empty(sample_count, dtype=dtype)
+        for i, entry in enumerate(entries):
+            class_index = entry.class_index
+            class_id = class_ids[class_index]
+            start_offset = entry.start_offset
+            end_offset = entry.end_offset
+            filename = entry.filename
+            entries_array[i] = (
+                class_index,
+                class_id,
+                start_offset,
+                end_offset,
+                filename,
+            )
+
+        entries_path = self._get_entries_path(*args, **kwargs)
+        self._save_extra(entries_array, entries_path)
+
+    def _dump_class_ids(self, *args, **kwargs) -> None:
+        entries_path = self._get_entries_path(*args, **kwargs)
+        entries_array = self._load_extra(entries_path)
+
+        max_class_id_length, max_class_index = -1, -1
+        for entry in entries_array:
+            class_index, class_id = entry["class_index"], entry["class_id"]
+            max_class_index = max(int(class_index), max_class_index)
+            max_class_id_length = max(len(str(class_id)), max_class_id_length)
+
+        class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}")
+        for entry in entries_array:
+            class_index, class_id = entry["class_index"], entry["class_id"]
+            class_ids_array[class_index] = class_id
+        class_ids_path = self._get_class_ids_path(*args, **kwargs)
+        self._save_extra(class_ids_array, class_ids_path)
+
+    def _dump_extra(self, *args, **kwargs) -> None:
+        self._dump_entries(*args, *kwargs)
+        self._dump_class_ids(*args, *kwargs)
+
+    def dump_extra(self, root: Optional[str] = None) -> None:
+        return self._dump_extra(root)
app/dinov2/data/loaders.py ADDED
@@ -0,0 +1,222 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+from enum import Enum
+from typing import Any, Callable, List, Optional, TypeVar
+
+import torch
+from torch.utils.data import Sampler
+
+from .datasets import ImageNet, ImageNet22k
+from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
+
+
+logger = logging.getLogger("dinov2")
+
+
+class SamplerType(Enum):
+    DISTRIBUTED = 0
+    EPOCH = 1
+    INFINITE = 2
+    SHARDED_INFINITE = 3
+    SHARDED_INFINITE_NEW = 4
+
+
+def _make_bool_str(b: bool) -> str:
+    return "yes" if b else "no"
+
+
+def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
+    def transform(sample):
+        image, target = sample
+        if image_transform is not None:
+            image = image_transform(image)
+        if target_transform is not None:
+            target = target_transform(target)
+        return image, target
+
+    return transform
+
+
+def _parse_dataset_str(dataset_str: str):
+    tokens = dataset_str.split(":")
+
+    name = tokens[0]
+    kwargs = {}
+
+    for token in tokens[1:]:
+        key, value = token.split("=")
+        assert key in ("root", "extra", "split")
+        kwargs[key] = value
+
+    if name == "ImageNet":
+        class_ = ImageNet
+        if "split" in kwargs:
+            kwargs["split"] = ImageNet.Split[kwargs["split"]]
+    elif name == "ImageNet22k":
+        class_ = ImageNet22k
+    else:
+        raise ValueError(f'Unsupported dataset "{name}"')
+
+    return class_, kwargs
+
+
+def make_dataset(
+    *,
+    dataset_str: str,
+    transform: Optional[Callable] = None,
+    target_transform: Optional[Callable] = None,
+):
+    """
+    Creates a dataset with the specified parameters.
+
+    Args:
+        dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
+        transform: A transform to apply to images.
+        target_transform: A transform to apply to targets.
+
+    Returns:
+        The created dataset.
+    """
+    logger.info(f'using dataset: "{dataset_str}"')
+
+    class_, kwargs = _parse_dataset_str(dataset_str)
+    dataset = class_(transform=transform, target_transform=target_transform, **kwargs)
+
+    logger.info(f"# of dataset samples: {len(dataset):,d}")
+
+    # Aggregated datasets do not expose (yet) these attributes, so add them.
+    if not hasattr(dataset, "transform"):
+        setattr(dataset, "transform", transform)
+    if not hasattr(dataset, "target_transform"):
+        setattr(dataset, "target_transform", target_transform)
+
+    return dataset
+
+
+def _make_sampler(
+    *,
+    dataset,
+    type: Optional[SamplerType] = None,
+    shuffle: bool = False,
+    seed: int = 0,
+    size: int = -1,
+    advance: int = 0,
+) -> Optional[Sampler]:
+    sample_count = len(dataset)
+
+    if type == SamplerType.INFINITE:
+        logger.info("sampler: infinite")
+        if size > 0:
+            raise ValueError("sampler size > 0 is invalid")
+        return InfiniteSampler(
+            sample_count=sample_count,
+            shuffle=shuffle,
+            seed=seed,
+            advance=advance,
+        )
+    elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
+        logger.info("sampler: sharded infinite")
+        if size > 0:
+            raise ValueError("sampler size > 0 is invalid")
+        # TODO: Remove support for old shuffling
+        use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
+        return ShardedInfiniteSampler(
+            sample_count=sample_count,
+            shuffle=shuffle,
+            seed=seed,
+            advance=advance,
+            use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
+        )
+    elif type == SamplerType.EPOCH:
+        logger.info("sampler: epoch")
+        if advance > 0:
+            raise NotImplementedError("sampler advance > 0 is not supported")
+        size = size if size > 0 else sample_count
+        logger.info(f"# of samples / epoch: {size:,d}")
+        return EpochSampler(
+            size=size,
+            sample_count=sample_count,
+            shuffle=shuffle,
+            seed=seed,
+        )
+    elif type == SamplerType.DISTRIBUTED:
+        logger.info("sampler: distributed")
+        if size > 0:
+            raise ValueError("sampler size > 0 is invalid")
+        if advance > 0:
+            raise ValueError("sampler advance > 0 is invalid")
+        return torch.utils.data.DistributedSampler(
+            dataset=dataset,
+            shuffle=shuffle,
+            seed=seed,
+            drop_last=False,
+        )
+
+    logger.info("sampler: none")
+    return None
+
+
+T = TypeVar("T")
+
+
+def make_data_loader(
+    *,
+    dataset,
+    batch_size: int,
+    num_workers: int,
+    shuffle: bool = True,
+    seed: int = 0,
+    sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
+    sampler_size: int = -1,
+    sampler_advance: int = 0,
+    drop_last: bool = True,
+    persistent_workers: bool = False,
+    collate_fn: Optional[Callable[[List[T]], Any]] = None,
+):
+    """
+    Creates a data loader with the specified parameters.
+
+    Args:
+        dataset: A dataset (third party, LaViDa or WebDataset).
+        batch_size: The size of batches to generate.
+        num_workers: The number of workers to use.
+        shuffle: Whether to shuffle samples.
+        seed: The random seed to use.
+        sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
+        sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
+        sampler_advance: How many samples to skip (when applicable).
+        drop_last: Whether the last non-full batch of data should be dropped.
+        persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once.
+        collate_fn: Function that performs batch collation
+    """
+
+    sampler = _make_sampler(
+        dataset=dataset,
+        type=sampler_type,
+        shuffle=shuffle,
+        seed=seed,
+        size=sampler_size,
+        advance=sampler_advance,
+    )
+
+    logger.info("using PyTorch data loader")
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        sampler=sampler,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        pin_memory=True,
+        drop_last=drop_last,
+        persistent_workers=persistent_workers,
+        collate_fn=collate_fn,
+    )
+
+    try:
+        logger.info(f"# of batches: {len(data_loader):,d}")
+    except TypeError:  # data loader has no length
+        logger.info("infinite data loader")
+    return data_loader
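
Usage sketch tying the two entry points together (not part of the commit; root/extra are placeholder paths, and shuffle plus the INFINITE sampler are the defaults):

    from dinov2.data import SamplerType, make_data_loader, make_dataset

    dataset = make_dataset(dataset_str="ImageNet:split=TRAIN:root=<ROOT>:extra=<EXTRA>")
    data_loader = make_data_loader(
        dataset=dataset,
        batch_size=64,
        num_workers=10,
        sampler_type=SamplerType.INFINITE,
    )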
app/dinov2/data/masking.py ADDED
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import random
+import math
+import numpy as np
+
+
+class MaskingGenerator:
+    def __init__(
+        self,
+        input_size,
+        num_masking_patches=None,
+        min_num_patches=4,
+        max_num_patches=None,
+        min_aspect=0.3,
+        max_aspect=None,
+    ):
+        if not isinstance(input_size, tuple):
+            input_size = (input_size,) * 2
+        self.height, self.width = input_size
+
+        self.num_patches = self.height * self.width
+        self.num_masking_patches = num_masking_patches
+
+        self.min_num_patches = min_num_patches
+        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
+
+        max_aspect = max_aspect or 1 / min_aspect
+        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
+
+    def __repr__(self):
+        repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
+            self.height,
+            self.width,
+            self.min_num_patches,
+            self.max_num_patches,
+            self.num_masking_patches,
+            self.log_aspect_ratio[0],
+            self.log_aspect_ratio[1],
+        )
+        return repr_str
+
+    def get_shape(self):
+        return self.height, self.width
+
+    def _mask(self, mask, max_mask_patches):
+        delta = 0
+        for _ in range(10):
+            target_area = random.uniform(self.min_num_patches, max_mask_patches)
+            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+            h = int(round(math.sqrt(target_area * aspect_ratio)))
+            w = int(round(math.sqrt(target_area / aspect_ratio)))
+            if w < self.width and h < self.height:
+                top = random.randint(0, self.height - h)
+                left = random.randint(0, self.width - w)
+
+                num_masked = mask[top : top + h, left : left + w].sum()
+                # Overlap
+                if 0 < h * w - num_masked <= max_mask_patches:
+                    for i in range(top, top + h):
+                        for j in range(left, left + w):
+                            if mask[i, j] == 0:
+                                mask[i, j] = 1
+                                delta += 1
+
+                if delta > 0:
+                    break
+        return delta
+
+    def __call__(self, num_masking_patches=0):
+        mask = np.zeros(shape=self.get_shape(), dtype=bool)
+        mask_count = 0
+        while mask_count < num_masking_patches:
+            max_mask_patches = num_masking_patches - mask_count
+            max_mask_patches = min(max_mask_patches, self.max_num_patches)
+
+            delta = self._mask(mask, max_mask_patches)
+            if delta == 0:
+                break
+            else:
+                mask_count += delta
+
+        return mask
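
Usage sketch (not part of the commit): with 224x224 inputs and 16x16 patches the token grid is 14x14, so input_size=14; the values below are for illustration only.

    from dinov2.data import MaskingGenerator

    gen = MaskingGenerator(input_size=14, max_num_patches=98)
    mask = gen(num_masking_patches=60)
    # mask is a (14, 14) boolean array with up to 60 patches set to True,
    # masked in random block-shaped regions.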
app/dinov2/data/samplers.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ from typing import Any, Optional
8
+ import warnings
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch.utils.data.sampler import Sampler
13
+
14
+ import dinov2.distributed as distributed
15
+
16
+
17
+ class EpochSampler(Sampler):
18
+ def __init__(
19
+ self,
20
+ *,
21
+ size: int,
22
+ sample_count: int,
23
+ shuffle: bool = False,
24
+ seed: int = 0,
25
+ start: Optional[int] = None,
26
+ step: Optional[int] = None,
27
+ ):
28
+ self._size = size
29
+ self._sample_count = sample_count
30
+ self._shuffle = shuffle
31
+ self._seed = seed
32
+ self._start = distributed.get_global_rank() if start is None else start
33
+ self._step = distributed.get_global_size() if step is None else step
34
+ self._epoch = 0
35
+
36
+ def __iter__(self):
37
+ count = (self._size + self._sample_count - 1) // self._sample_count
38
+ tiled_indices = np.tile(np.arange(self._sample_count), count)
39
+ if self._shuffle:
40
+ seed = self._seed * self._epoch if self._seed != 0 else self._epoch
41
+ rng = np.random.default_rng(seed)
42
+ iterable = rng.choice(tiled_indices, self._size, replace=False)
43
+ else:
44
+ iterable = tiled_indices[: self._size]
45
+
46
+ yield from itertools.islice(iterable, self._start, None, self._step)
47
+
48
+ def __len__(self):
49
+ return (self._size - self._start + self._step - 1) // self._step
50
+
51
+ def set_epoch(self, epoch):
52
+ self._epoch = epoch
53
+
54
+
55
+ def _get_numpy_dtype(size: int) -> Any:
56
+ return np.int32 if size <= 2**31 else np.int64
57
+
58
+
59
+ def _get_torch_dtype(size: int) -> Any:
60
+ return torch.int32 if size <= 2**31 else torch.int64
61
+
62
+
63
+ def _generate_randperm_indices(*, size: int, generator: torch.Generator):
64
+ """Generate the indices of a random permutation."""
65
+ dtype = _get_torch_dtype(size)
66
+ # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
67
+ perm = torch.arange(size, dtype=dtype)
68
+ for i in range(size):
69
+ j = torch.randint(i, size, size=(1,), generator=generator).item()
70
+
71
+ # Always swap even if no-op
72
+ value = perm[j].item()
73
+ perm[j] = perm[i].item()
74
+ perm[i] = value
75
+ yield value
76
+
77
+
78
+ class InfiniteSampler(Sampler):
79
+ def __init__(
80
+ self,
81
+ *,
82
+ sample_count: int,
83
+ shuffle: bool = False,
84
+ seed: int = 0,
85
+ start: Optional[int] = None,
86
+ step: Optional[int] = None,
87
+ advance: int = 0,
88
+ ):
89
+ self._sample_count = sample_count
90
+ self._seed = seed
91
+ self._shuffle = shuffle
92
+ self._start = distributed.get_global_rank() if start is None else start
93
+ self._step = distributed.get_global_size() if step is None else step
94
+ self._advance = advance
95
+
96
+ def __iter__(self):
97
+ if self._shuffle:
98
+ iterator = self._shuffled_iterator()
99
+ else:
100
+ iterator = self._iterator()
101
+
102
+ yield from itertools.islice(iterator, self._advance, None)
103
+
104
+ def _iterator(self):
105
+ assert not self._shuffle
106
+
107
+ while True:
108
+ iterable = range(self._sample_count)
109
+ yield from itertools.islice(iterable, self._start, None, self._step)
110
+
111
+ def _shuffled_iterator(self):
112
+ assert self._shuffle
113
+
114
+ # Instantiate a generator here (rather than in the ctor) to keep the class
115
+ # picklable (requirement of mp.spawn)
116
+ generator = torch.Generator().manual_seed(self._seed)
117
+
118
+ while True:
119
+ iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
120
+ yield from itertools.islice(iterable, self._start, None, self._step)
121
+
122
+
123
+ # The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
124
+ # but avoids a full in-place random permutation generation.
125
+ def _shuffle_tensor_slice(
126
+ *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
127
+ ) -> np.ndarray:
128
+ stop = len(tensor)
129
+ count = stop // step
130
+ drop_count = stop - step * count
131
+ if drop_count:
132
+ warnings.warn(f"# of dropped samples: {drop_count}")
133
+
134
+ dtype = _get_numpy_dtype(stop)
135
+ result = np.empty(count, dtype=dtype)
136
+
137
+ for i in range(count):
138
+ j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0
139
+
140
+ result[i] = result[j]
141
+ result[j] = tensor[start + i * step].item()
142
+
143
+ return result
144
+
145
+
146
+ def _new_shuffle_tensor_slice(
147
+ *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
148
+ ) -> np.ndarray:
149
+ stop = len(tensor)
150
+ count = stop // step
151
+ dtype = torch.int64 # Needed for using randperm result as indices
153
+ drop_count = stop - step * count
154
+ if drop_count:
155
+ warnings.warn(f"# of dropped samples: {drop_count}")
156
+ indices = torch.randperm(count, dtype=dtype, generator=generator)
157
+ return tensor[start::step][indices].numpy()
158
+
159
+
160
+ def _make_seed(seed: int, start: int, iter_count: int) -> int:
161
+ # NOTE: Tried a few variants (including iter_count << 32), this one worked best.
162
+ return seed + start + (iter_count << 24)
163
+
164
+
165
+ class ShardedInfiniteSampler(Sampler):
166
+ def __init__(
167
+ self,
168
+ *,
169
+ sample_count: int,
170
+ shuffle: bool = False,
171
+ seed: int = 0,
172
+ start: Optional[int] = None,
173
+ step: Optional[int] = None,
174
+ advance: int = 0,
175
+ use_new_shuffle_tensor_slice: bool = False,
176
+ ):
177
+ self._sample_count = sample_count
178
+ self._seed = seed
179
+ self._shuffle = shuffle
180
+ self._start = distributed.get_global_rank() if start is None else start
181
+ self._step = distributed.get_global_size() if step is None else step
182
+ self._advance = advance
183
+ self._iter_count = 0
184
+ self._shuffle_tensor_slice_fn = (
185
+ _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
186
+ )
187
+
188
+ def __iter__(self):
189
+ iter_count = self._advance // self._sample_count
190
+ if iter_count > 0:
191
+ self._advance -= iter_count * self._sample_count
192
+ self._iter_count += iter_count
193
+
194
+ if self._shuffle:
195
+ iterator = self._shuffled_iterator()
196
+ else:
197
+ iterator = self._iterator()
198
+
199
+ yield from itertools.islice(iterator, self._advance, None)
200
+
201
+ def _iterator(self):
202
+ assert not self._shuffle
203
+
204
+ while True:
205
+ iterable = range(self._sample_count)
206
+ yield from itertools.islice(iterable, self._start, None, self._step)
207
+
208
+ def _shuffled_iterator(self):
209
+ assert self._shuffle
210
+
211
+ # Instantiate a generator here (rather than in the ctor) to keep the class
212
+ # picklable (requirement of mp.spawn)
213
+ generator = torch.Generator()
214
+
215
+ # Always shuffle everything first
216
+ generator.manual_seed(self._seed)
217
+ dtype = _get_torch_dtype(self._sample_count)
218
+ perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
219
+
220
+ while True:
221
+ # Re-seed on each iteration to allow skipping whole permutations
222
+ seed = _make_seed(self._seed, self._start, self._iter_count)
223
+ generator.manual_seed(seed)
224
+
225
+ iterable = self._shuffle_tensor_slice_fn(
226
+ tensor=perm, start=self._start, step=self._step, generator=generator
227
+ )
228
+ yield from iterable
229
+ self._iter_count += 1
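
For orientation, a minimal sketch of how these samplers are meant to be driven: each rank reads the slice `start::step` of a shared (re)shuffled index stream, and `advance` skips already-consumed samples when resuming. Outside a distributed run, `start`/`step` default to rank 0 and world size 1. The dataset here is a stand-in:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset; in DINOv2 this would be the pretraining image dataset.
dataset = TensorDataset(torch.arange(1000, dtype=torch.float32))

sampler = ShardedInfiniteSampler(
    sample_count=len(dataset),
    shuffle=True,
    seed=0,
    advance=0,  # set to the number of already-consumed samples when resuming
)
loader = DataLoader(dataset, sampler=sampler, batch_size=32)

data_iter = iter(loader)
for _ in range(3):  # the sampler is infinite: the loop, not the loader, bounds training
    (batch,) = next(data_iter)
```

Because the stream never raises `StopIteration`, training loops built on these samplers count iterations rather than epochs.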
app/dinov2/data/transforms.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Sequence
7
+
8
+ import torch
9
+ from torchvision import transforms
10
+
11
+
12
+ class GaussianBlur(transforms.RandomApply):
13
+ """
14
+ Apply Gaussian Blur to the PIL image.
15
+ """
16
+
17
+ def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
18
+ # NOTE: torchvision's RandomApply applies the transforms with probability ``p``
19
+ # (and returns the original image otherwise), so ``p`` is passed through directly.
20
+ transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
21
+ super().__init__(transforms=[transform], p=p)
22
+
23
+
24
+ class MaybeToTensor(transforms.ToTensor):
25
+ """
26
+ Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.
27
+ """
28
+
29
+ def __call__(self, pic):
30
+ """
31
+ Args:
32
+ pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
33
+ Returns:
34
+ Tensor: Converted image.
35
+ """
36
+ if isinstance(pic, torch.Tensor):
37
+ return pic
38
+ return super().__call__(pic)
39
+
40
+
41
+ # Use timm's names
42
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
43
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
44
+
45
+
46
+ def make_normalize_transform(
47
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
48
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
49
+ ) -> transforms.Normalize:
50
+ return transforms.Normalize(mean=mean, std=std)
51
+
52
+
53
+ # This roughly matches torchvision's preset for classification training:
54
+ # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
55
+ def make_classification_train_transform(
56
+ *,
57
+ crop_size: int = 224,
58
+ interpolation=transforms.InterpolationMode.BICUBIC,
59
+ hflip_prob: float = 0.5,
60
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
61
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
62
+ ):
63
+ transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
64
+ if hflip_prob > 0.0:
65
+ transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
66
+ transforms_list.extend(
67
+ [
68
+ MaybeToTensor(),
69
+ make_normalize_transform(mean=mean, std=std),
70
+ ]
71
+ )
72
+ return transforms.Compose(transforms_list)
73
+
74
+
75
+ # This matches (roughly) torchvision's preset for classification evaluation:
76
+ # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
77
+ def make_classification_eval_transform(
78
+ *,
79
+ resize_size: int = 256,
80
+ interpolation=transforms.InterpolationMode.BICUBIC,
81
+ crop_size: int = 224,
82
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
83
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
84
+ ) -> transforms.Compose:
85
+ transforms_list = [
86
+ transforms.Resize(resize_size, interpolation=interpolation),
87
+ transforms.CenterCrop(crop_size),
88
+ MaybeToTensor(),
89
+ make_normalize_transform(mean=mean, std=std),
90
+ ]
91
+ return transforms.Compose(transforms_list)
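
A short usage sketch of the evaluation preset above, assuming a Pillow image as input; `MaybeToTensor` also lets the same pipeline accept an already-converted tensor:

```python
import torch
from PIL import Image

transform = make_classification_eval_transform()  # resize 256 -> center-crop 224 -> normalize

img = Image.new("RGB", (500, 375))  # placeholder image
x = transform(img)
assert x.shape == (3, 224, 224)

# Re-applying ToTensor to a tensor would normally raise; MaybeToTensor passes it through.
assert torch.equal(MaybeToTensor()(x), x)
```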
app/dinov2/distributed/__init__.py ADDED
@@ -0,0 +1,270 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import random
8
+ import re
9
+ import socket
10
+ from typing import Dict, List
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+
15
+ _LOCAL_RANK = -1
16
+ _LOCAL_WORLD_SIZE = -1
17
+
18
+
19
+ def is_enabled() -> bool:
20
+ """
21
+ Returns:
22
+ True if distributed training is enabled
23
+ """
24
+ return dist.is_available() and dist.is_initialized()
25
+
26
+
27
+ def get_global_size() -> int:
28
+ """
29
+ Returns:
30
+ The number of processes in the process group
31
+ """
32
+ return dist.get_world_size() if is_enabled() else 1
33
+
34
+
35
+ def get_global_rank() -> int:
36
+ """
37
+ Returns:
38
+ The rank of the current process within the global process group.
39
+ """
40
+ return dist.get_rank() if is_enabled() else 0
41
+
42
+
43
+ def get_local_rank() -> int:
44
+ """
45
+ Returns:
46
+ The rank of the current process within the local (per-machine) process group.
47
+ """
48
+ if not is_enabled():
49
+ return 0
50
+ assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
51
+ return _LOCAL_RANK
52
+
53
+
54
+ def get_local_size() -> int:
55
+ """
56
+ Returns:
57
+ The size of the per-machine process group,
58
+ i.e. the number of processes per machine.
59
+ """
60
+ if not is_enabled():
61
+ return 1
62
+ assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
63
+ return _LOCAL_WORLD_SIZE
64
+
65
+
66
+ def is_main_process() -> bool:
67
+ """
68
+ Returns:
69
+ True if the current process is the main one.
70
+ """
71
+ return get_global_rank() == 0
72
+
73
+
74
+ def _restrict_print_to_main_process() -> None:
75
+ """
76
+ This function disables printing when not in the main process
77
+ """
78
+ import builtins as __builtin__
79
+
80
+ builtin_print = __builtin__.print
81
+
82
+ def print(*args, **kwargs):
83
+ force = kwargs.pop("force", False)
84
+ if is_main_process() or force:
85
+ builtin_print(*args, **kwargs)
86
+
87
+ __builtin__.print = print
88
+
89
+
90
+ def _get_master_port(seed: int = 0) -> int:
91
+ MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)
92
+
93
+ master_port_str = os.environ.get("MASTER_PORT")
94
+ if master_port_str is None:
95
+ rng = random.Random(seed)
96
+ return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
97
+
98
+ return int(master_port_str)
99
+
100
+
101
+ def _get_available_port() -> int:
102
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
103
+ # A "" host address means INADDR_ANY i.e. binding to all interfaces.
104
+ # Note this is not compatible with IPv6.
105
+ s.bind(("", 0))
106
+ port = s.getsockname()[1]
107
+ return port
108
+
109
+
110
+ _TORCH_DISTRIBUTED_ENV_VARS = (
111
+ "MASTER_ADDR",
112
+ "MASTER_PORT",
113
+ "RANK",
114
+ "WORLD_SIZE",
115
+ "LOCAL_RANK",
116
+ "LOCAL_WORLD_SIZE",
117
+ )
118
+
119
+
120
+ def _collect_env_vars() -> Dict[str, str]:
121
+ return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ}
122
+
123
+
124
+ def _is_slurm_job_process() -> bool:
125
+ return "SLURM_JOB_ID" in os.environ
126
+
127
+
128
+ def _parse_slurm_node_list(s: str) -> List[str]:
129
+ nodes = []
130
+ # Extract "hostname", "hostname[1-2,3,4-5]," substrings
131
+ p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
132
+ for m in p.finditer(s):
133
+ prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
134
+ for suffix in suffixes.split(","):
135
+ span = suffix.split("-")
136
+ if len(span) == 1:
137
+ nodes.append(prefix + suffix)
138
+ else:
139
+ width = len(span[0])
140
+ start, end = int(span[0]), int(span[1]) + 1
141
+ nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)])
142
+ return nodes
143
+
144
+
145
+ def _check_env_variable(key: str, new_value: str):
146
+ # Only check for difference with preset environment variables
147
+ if key in os.environ and os.environ[key] != new_value:
148
+ raise RuntimeError(f"Cannot export environment variables as {key} is already set")
149
+
150
+
151
+ class _TorchDistributedEnvironment:
152
+ def __init__(self):
153
+ self.master_addr = "127.0.0.1"
154
+ self.master_port = 0
155
+ self.rank = -1
156
+ self.world_size = -1
157
+ self.local_rank = -1
158
+ self.local_world_size = -1
159
+
160
+ if _is_slurm_job_process():
161
+ return self._set_from_slurm_env()
162
+
163
+ env_vars = _collect_env_vars()
164
+ if not env_vars:
165
+ # Environment is not set
166
+ pass
167
+ elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
168
+ # Environment is fully set
169
+ return self._set_from_preset_env()
170
+ else:
171
+ # Environment is partially set
172
+ collected_env_vars = ", ".join(env_vars.keys())
173
+ raise RuntimeError(f"Partially set environment: {collected_env_vars}")
174
+
175
+ if torch.cuda.device_count() > 0:
176
+ return self._set_from_local()
177
+
178
+ raise RuntimeError("Can't initialize PyTorch distributed environment")
179
+
180
+ # Slurm job created with sbatch, submitit, etc...
181
+ def _set_from_slurm_env(self):
182
+ # logger.info("Initialization from Slurm environment")
183
+ job_id = int(os.environ["SLURM_JOB_ID"])
184
+ node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
185
+ nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
186
+ assert len(nodes) == node_count
187
+
188
+ self.master_addr = nodes[0]
189
+ self.master_port = _get_master_port(seed=job_id)
190
+ self.rank = int(os.environ["SLURM_PROCID"])
191
+ self.world_size = int(os.environ["SLURM_NTASKS"])
192
+ assert self.rank < self.world_size
193
+ self.local_rank = int(os.environ["SLURM_LOCALID"])
194
+ self.local_world_size = self.world_size // node_count
195
+ assert self.local_rank < self.local_world_size
196
+
197
+ # Single node job with preset environment (i.e. torchrun)
198
+ def _set_from_preset_env(self):
199
+ # logger.info("Initialization from preset environment")
200
+ self.master_addr = os.environ["MASTER_ADDR"]
201
+ self.master_port = os.environ["MASTER_PORT"]
202
+ self.rank = int(os.environ["RANK"])
203
+ self.world_size = int(os.environ["WORLD_SIZE"])
204
+ assert self.rank < self.world_size
205
+ self.local_rank = int(os.environ["LOCAL_RANK"])
206
+ self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
207
+ assert self.local_rank < self.local_world_size
208
+
209
+ # Single node and GPU job (i.e. local script run)
210
+ def _set_from_local(self):
211
+ # logger.info("Initialization from local")
212
+ self.master_addr = "127.0.0.1"
213
+ self.master_port = _get_available_port()
214
+ self.rank = 0
215
+ self.world_size = 1
216
+ self.local_rank = 0
217
+ self.local_world_size = 1
218
+
219
+ def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
220
+ # See the "Environment variable initialization" section from
221
+ # https://pytorch.org/docs/stable/distributed.html for the complete list of
222
+ # environment variables required for the env:// initialization method.
223
+ env_vars = {
224
+ "MASTER_ADDR": self.master_addr,
225
+ "MASTER_PORT": str(self.master_port),
226
+ "RANK": str(self.rank),
227
+ "WORLD_SIZE": str(self.world_size),
228
+ "LOCAL_RANK": str(self.local_rank),
229
+ "LOCAL_WORLD_SIZE": str(self.local_world_size),
230
+ }
231
+ if not overwrite:
232
+ for k, v in env_vars.items():
233
+ _check_env_variable(k, v)
234
+
235
+ os.environ.update(env_vars)
236
+ return self
237
+
238
+
239
+ def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
240
+ """Enable distributed mode
241
+
242
+ Args:
243
+ set_cuda_current_device: If True, call torch.cuda.set_device() to set the
244
+ current PyTorch CUDA device to the one matching the local rank.
245
+ overwrite: If True, overwrites already set variables. Else fails.
+ allow_nccl_timeout: If True, sets NCCL_ASYNC_ERROR_HANDLING=1 so that NCCL collectives honor the distributed timeout.
246
+ """
247
+
248
+ global _LOCAL_RANK, _LOCAL_WORLD_SIZE
249
+ if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0:
250
+ raise RuntimeError("Distributed mode has already been enabled")
251
+ torch_env = _TorchDistributedEnvironment()
252
+ torch_env.export(overwrite=overwrite)
253
+
254
+ if set_cuda_current_device:
255
+ torch.cuda.set_device(torch_env.local_rank)
256
+
257
+ if allow_nccl_timeout:
258
+ # This allows the torch distributed timeout to take effect when using the NCCL backend
259
+ key, value = "NCCL_ASYNC_ERROR_HANDLING", "1"
260
+ if not overwrite:
261
+ _check_env_variable(key, value)
262
+ os.environ[key] = value
263
+
264
+ dist.init_process_group(backend="nccl")
265
+ dist.barrier()
266
+
267
+ # Finalize setup
268
+ _LOCAL_RANK = torch_env.local_rank
269
+ _LOCAL_WORLD_SIZE = torch_env.local_world_size
270
+ _restrict_print_to_main_process()
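
A minimal sketch of the intended entry point in a training script. Note that `enable()` initializes the NCCL backend, so this assumes at least one visible CUDA device; under Slurm or torchrun the same call picks up the corresponding environment automatically:

```python
import torch
import dinov2.distributed as distributed  # import path as used elsewhere in this commit

def main():
    distributed.enable(set_cuda_current_device=True, overwrite=True)
    rank, world_size = distributed.get_global_rank(), distributed.get_global_size()
    device = torch.device("cuda", distributed.get_local_rank())
    # After enable(), print() is silenced on non-main ranks unless force=True is passed.
    print(f"rank {rank}/{world_size} using {device}", force=True)

if __name__ == "__main__":
    main()
```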
app/dinov2/eval/.DS_Store ADDED
Binary file (6.15 kB). View file
 
app/dinov2/eval/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
app/dinov2/eval/depth/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
app/dinov2/eval/depth/models/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .backbones import * # noqa: F403
7
+ from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss
8
+ from .decode_heads import * # noqa: F403
9
+ from .depther import * # noqa: F403
10
+ from .losses import * # noqa: F403
app/dinov2/eval/depth/models/backbones/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .vision_transformer import DinoVisionTransformer
app/dinov2/eval/depth/models/backbones/vision_transformer.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from mmcv.runner import BaseModule
7
+
8
+ from ..builder import BACKBONES
9
+
10
+
11
+ @BACKBONES.register_module()
12
+ class DinoVisionTransformer(BaseModule):
13
+ """Vision Transformer."""
14
+
15
+ def __init__(self, *args, **kwargs):
16
+ super().__init__()
app/dinov2/eval/depth/models/builder.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import warnings
7
+
8
+ from mmcv.cnn import MODELS as MMCV_MODELS
9
+ from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
10
+ from mmcv.utils import Registry
11
+
12
+ MODELS = Registry("models", parent=MMCV_MODELS)
13
+ ATTENTION = Registry("attention", parent=MMCV_ATTENTION)
14
+
15
+
16
+ BACKBONES = MODELS
17
+ NECKS = MODELS
18
+ HEADS = MODELS
19
+ LOSSES = MODELS
20
+ DEPTHER = MODELS
21
+
22
+
23
+ def build_backbone(cfg):
24
+ """Build backbone."""
25
+ return BACKBONES.build(cfg)
26
+
27
+
28
+ def build_neck(cfg):
29
+ """Build neck."""
30
+ return NECKS.build(cfg)
31
+
32
+
33
+ def build_head(cfg):
34
+ """Build head."""
35
+ return HEADS.build(cfg)
36
+
37
+
38
+ def build_loss(cfg):
39
+ """Build loss."""
40
+ return LOSSES.build(cfg)
41
+
42
+
43
+ def build_depther(cfg, train_cfg=None, test_cfg=None):
44
+ """Build depther."""
45
+ if train_cfg is not None or test_cfg is not None:
46
+ warnings.warn("train_cfg and test_cfg are deprecated, please specify them in model", UserWarning)
47
+ assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field"
48
+ assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field"
49
+ return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
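
These registries follow mmcv's config-driven pattern: a model is a nested dict whose `type` fields name registered classes, and `build_depther` instantiates the whole tree. An illustrative sketch; the field values are placeholders rather than a validated DINOv2 config, and it assumes mmcv-full plus this package's losses module are importable:

```python
# Illustrative config; the registered names come from this commit's modules.
cfg = dict(
    type="DepthEncoderDecoder",
    backbone=dict(type="DinoVisionTransformer"),
    decode_head=dict(
        type="BNHead",
        in_channels=[768],  # one feature level from the ViT backbone
        in_index=(0,),
        channels=768,
        min_depth=1e-3,
        max_depth=10.0,
    ),
)

# model = build_depther(cfg)  # resolves each 'type' against the registries above
```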
app/dinov2/eval/depth/models/decode_heads/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dpt_head import DPTHead
7
+ from .linear_head import BNHead
app/dinov2/eval/depth/models/decode_heads/decode_head.py ADDED
@@ -0,0 +1,225 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import copy
7
+ from abc import ABCMeta, abstractmethod
8
+
9
+ import mmcv
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ from mmcv.runner import BaseModule, auto_fp16, force_fp32
14
+
15
+ from ...ops import resize
16
+ from ..builder import build_loss
17
+
18
+
19
+ class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta):
20
+ """Base class for BaseDecodeHead.
21
+
22
+ Args:
23
+ in_channels (List): Input channels.
24
+ channels (int): Channels after modules, before conv_depth.
25
+ conv_cfg (dict|None): Config of conv layers. Default: None.
26
+ act_cfg (dict): Config of activation layers.
27
+ Default: dict(type='ReLU')
28
+ loss_decode (dict): Config of decode loss.
29
+ Default: dict(type='SigLoss').
30
+ sampler (dict|None): The config of depth map sampler.
31
+ Default: None.
32
+ align_corners (bool): align_corners argument of F.interpolate.
33
+ Default: False.
34
+ min_depth (float): Min depth in dataset setting.
35
+ Default: 1e-3.
36
+ max_depth (float|None): Max depth in dataset setting.
37
+ Default: None.
38
+ norm_cfg (dict|None): Config of norm layers.
39
+ Default: None.
40
+ classify (bool): Whether to predict depth in a cls.-reg. manner.
41
+ Default: False.
42
+ n_bins (int): The number of bins used in cls. step.
43
+ Default: 256.
44
+ bins_strategy (str): The discrete strategy used in cls. step.
45
+ Default: 'UD'.
46
+ norm_strategy (str): The norm strategy on cls. probability
47
+ distribution. Default: 'linear'
48
+ scale_up (bool): Whether to predict depth in a scale-up manner.
49
+ Default: False.
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ in_channels,
55
+ channels=96,
56
+ conv_cfg=None,
57
+ act_cfg=dict(type="ReLU"),
58
+ loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10),
59
+ sampler=None,
60
+ align_corners=False,
61
+ min_depth=1e-3,
62
+ max_depth=None,
63
+ norm_cfg=None,
64
+ classify=False,
65
+ n_bins=256,
66
+ bins_strategy="UD",
67
+ norm_strategy="linear",
68
+ scale_up=False,
69
+ ):
70
+ super(DepthBaseDecodeHead, self).__init__()
71
+
72
+ self.in_channels = in_channels
73
+ self.channels = channels
74
+ self.conv_cfg = conv_cfg
75
+ self.act_cfg = act_cfg
76
+ if isinstance(loss_decode, dict):
77
+ self.loss_decode = build_loss(loss_decode)
78
+ elif isinstance(loss_decode, (list, tuple)):
79
+ self.loss_decode = nn.ModuleList()
80
+ for loss in loss_decode:
81
+ self.loss_decode.append(build_loss(loss))
82
+ self.align_corners = align_corners
83
+ self.min_depth = min_depth
84
+ self.max_depth = max_depth
85
+ self.norm_cfg = norm_cfg
86
+ self.classify = classify
87
+ self.n_bins = n_bins
88
+ self.scale_up = scale_up
89
+
90
+ if self.classify:
91
+ assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
92
+ assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
93
+
94
+ self.bins_strategy = bins_strategy
95
+ self.norm_strategy = norm_strategy
96
+ self.softmax = nn.Softmax(dim=1)
97
+ self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
98
+ else:
99
+ self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
100
+
101
+ self.fp16_enabled = False
102
+ self.relu = nn.ReLU()
103
+ self.sigmoid = nn.Sigmoid()
104
+
105
+ def extra_repr(self):
106
+ """Extra repr."""
107
+ s = f"align_corners={self.align_corners}"
108
+ return s
109
+
110
+ @auto_fp16()
111
+ @abstractmethod
112
+ def forward(self, inputs, img_metas):
113
+ """Placeholder of forward function."""
114
+ pass
115
+
116
+ def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg):
117
+ """Forward function for training.
118
+ Args:
119
+ inputs (list[Tensor]): List of multi-level img features.
120
+ img_metas (list[dict]): List of image info dict where each dict
121
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
122
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
123
+ For details on the values of these keys see
124
+ `depth/datasets/pipelines/formatting.py:Collect`.
125
+ depth_gt (Tensor): GT depth
126
+ train_cfg (dict): The training config.
127
+
128
+ Returns:
129
+ dict[str, Tensor]: a dictionary of loss components
130
+ """
131
+ depth_pred = self.forward(inputs, img_metas)
132
+ losses = self.losses(depth_pred, depth_gt)
133
+
134
+ log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
135
+ losses.update(**log_imgs)
136
+
137
+ return losses
138
+
139
+ def forward_test(self, inputs, img_metas, test_cfg):
140
+ """Forward function for testing.
141
+ Args:
142
+ inputs (list[Tensor]): List of multi-level img features.
143
+ img_metas (list[dict]): List of image info dict where each dict
144
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
145
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
146
+ For details on the values of these keys see
147
+ `depth/datasets/pipelines/formatting.py:Collect`.
148
+ test_cfg (dict): The testing config.
149
+
150
+ Returns:
151
+ Tensor: Output depth map.
152
+ """
153
+ return self.forward(inputs, img_metas)
154
+
155
+ def depth_pred(self, feat):
156
+ """Prediction each pixel."""
157
+ if self.classify:
158
+ logit = self.conv_depth(feat)
159
+
160
+ if self.bins_strategy == "UD":
161
+ bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
162
+ elif self.bins_strategy == "SID":
163
+ bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
164
+
165
+ # following Adabins, default linear
166
+ if self.norm_strategy == "linear":
167
+ logit = torch.relu(logit)
168
+ eps = 0.1
169
+ logit = logit + eps
170
+ logit = logit / logit.sum(dim=1, keepdim=True)
171
+ elif self.norm_strategy == "softmax":
172
+ logit = torch.softmax(logit, dim=1)
173
+ elif self.norm_strategy == "sigmoid":
174
+ logit = torch.sigmoid(logit)
175
+ logit = logit / logit.sum(dim=1, keepdim=True)
176
+
177
+ output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
178
+
179
+ else:
180
+ if self.scale_up:
181
+ output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
182
+ else:
183
+ output = self.relu(self.conv_depth(feat)) + self.min_depth
184
+ return output
185
+
186
+ @force_fp32(apply_to=("depth_pred",))
187
+ def losses(self, depth_pred, depth_gt):
188
+ """Compute depth loss."""
189
+ loss = dict()
190
+ depth_pred = resize(
191
+ input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
192
+ )
193
+ if not isinstance(self.loss_decode, nn.ModuleList):
194
+ losses_decode = [self.loss_decode]
195
+ else:
196
+ losses_decode = self.loss_decode
197
+ for loss_decode in losses_decode:
198
+ if loss_decode.loss_name not in loss:
199
+ loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
200
+ else:
201
+ loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
202
+ return loss
203
+
204
+ def log_images(self, img, depth_pred, depth_gt, img_meta):
205
+ show_img = copy.deepcopy(img.detach().cpu().permute(1, 2, 0))
206
+ show_img = show_img.numpy().astype(np.float32)
207
+ show_img = mmcv.imdenormalize(
208
+ show_img,
209
+ img_meta["img_norm_cfg"]["mean"],
210
+ img_meta["img_norm_cfg"]["std"],
211
+ img_meta["img_norm_cfg"]["to_rgb"],
212
+ )
213
+ show_img = np.clip(show_img, 0, 255)
214
+ show_img = show_img.astype(np.uint8)
215
+ show_img = show_img[:, :, ::-1]
216
+ show_img = show_img.transpose(0, 2, 1)
217
+ show_img = show_img.transpose(1, 0, 2)
218
+
219
+ depth_pred = depth_pred / torch.max(depth_pred)
220
+ depth_gt = depth_gt / torch.max(depth_gt)
221
+
222
+ depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
223
+ depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
224
+
225
+ return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
app/dinov2/eval/depth/models/decode_heads/dpt_head.py ADDED
@@ -0,0 +1,270 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from mmcv.cnn import ConvModule, Linear, build_activation_layer
11
+ from mmcv.runner import BaseModule
12
+
13
+ from ...ops import resize
14
+ from ..builder import HEADS
15
+ from .decode_head import DepthBaseDecodeHead
16
+
17
+
18
+ class Interpolate(nn.Module):
19
+ def __init__(self, scale_factor, mode, align_corners=False):
20
+ super(Interpolate, self).__init__()
21
+ self.interp = nn.functional.interpolate
22
+ self.scale_factor = scale_factor
23
+ self.mode = mode
24
+ self.align_corners = align_corners
25
+
26
+ def forward(self, x):
27
+ x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
28
+ return x
29
+
30
+
31
+ class HeadDepth(nn.Module):
32
+ def __init__(self, features):
33
+ super(HeadDepth, self).__init__()
34
+ self.head = nn.Sequential(
35
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
36
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
37
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
38
+ nn.ReLU(),
39
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
40
+ )
41
+
42
+ def forward(self, x):
43
+ x = self.head(x)
44
+ return x
45
+
46
+
47
+ class ReassembleBlocks(BaseModule):
48
+ """ViTPostProcessBlock, process cls_token in ViT backbone output and
49
+ rearrange the feature vector to feature map.
50
+ Args:
51
+ in_channels (int): ViT feature channels. Default: 768.
52
+ out_channels (List): output channels of each stage.
53
+ Default: [96, 192, 384, 768].
54
+ readout_type (str): Type of readout operation. Default: 'ignore'.
55
+ patch_size (int): The patch size. Default: 16.
56
+ init_cfg (dict, optional): Initialization config dict. Default: None.
57
+ """
58
+
59
+ def __init__(
60
+ self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16, init_cfg=None
61
+ ):
62
+ super(ReassembleBlocks, self).__init__(init_cfg)
63
+
64
+ assert readout_type in ["ignore", "add", "project"]
65
+ self.readout_type = readout_type
66
+ self.patch_size = patch_size
67
+
68
+ self.projects = nn.ModuleList(
69
+ [
70
+ ConvModule(
71
+ in_channels=in_channels,
72
+ out_channels=out_channel,
73
+ kernel_size=1,
74
+ act_cfg=None,
75
+ )
76
+ for out_channel in out_channels
77
+ ]
78
+ )
79
+
80
+ self.resize_layers = nn.ModuleList(
81
+ [
82
+ nn.ConvTranspose2d(
83
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
84
+ ),
85
+ nn.ConvTranspose2d(
86
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
87
+ ),
88
+ nn.Identity(),
89
+ nn.Conv2d(
90
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
91
+ ),
92
+ ]
93
+ )
94
+ if self.readout_type == "project":
95
+ self.readout_projects = nn.ModuleList()
96
+ for _ in range(len(self.projects)):
97
+ self.readout_projects.append(
98
+ nn.Sequential(Linear(2 * in_channels, in_channels), build_activation_layer(dict(type="GELU")))
99
+ )
100
+
101
+ def forward(self, inputs):
102
+ assert isinstance(inputs, list)
103
+ out = []
104
+ for i, x in enumerate(inputs):
105
+ assert len(x) == 2
106
+ x, cls_token = x[0], x[1]
107
+ feature_shape = x.shape
108
+ if self.readout_type == "project":
109
+ x = x.flatten(2).permute((0, 2, 1))
110
+ readout = cls_token.unsqueeze(1).expand_as(x)
111
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
112
+ x = x.permute(0, 2, 1).reshape(feature_shape)
113
+ elif self.readout_type == "add":
114
+ x = x.flatten(2) + cls_token.unsqueeze(-1)
115
+ x = x.reshape(feature_shape)
116
+ else:
117
+ pass
118
+ x = self.projects[i](x)
119
+ x = self.resize_layers[i](x)
120
+ out.append(x)
121
+ return out
122
+
123
+
124
+ class PreActResidualConvUnit(BaseModule):
125
+ """ResidualConvUnit, pre-activate residual unit.
126
+ Args:
127
+ in_channels (int): number of channels in the input feature map.
128
+ act_cfg (dict): dictionary to construct and config activation layer.
129
+ norm_cfg (dict): dictionary to construct and config norm layer.
130
+ stride (int): stride of the first block. Default: 1
131
+ dilation (int): dilation rate for convs layers. Default: 1.
132
+ init_cfg (dict, optional): Initialization config dict. Default: None.
133
+ """
134
+
135
+ def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None):
136
+ super(PreActResidualConvUnit, self).__init__(init_cfg)
137
+
138
+ self.conv1 = ConvModule(
139
+ in_channels,
140
+ in_channels,
141
+ 3,
142
+ stride=stride,
143
+ padding=dilation,
144
+ dilation=dilation,
145
+ norm_cfg=norm_cfg,
146
+ act_cfg=act_cfg,
147
+ bias=False,
148
+ order=("act", "conv", "norm"),
149
+ )
150
+
151
+ self.conv2 = ConvModule(
152
+ in_channels,
153
+ in_channels,
154
+ 3,
155
+ padding=1,
156
+ norm_cfg=norm_cfg,
157
+ act_cfg=act_cfg,
158
+ bias=False,
159
+ order=("act", "conv", "norm"),
160
+ )
161
+
162
+ def forward(self, inputs):
163
+ inputs_ = inputs.clone()
164
+ x = self.conv1(inputs)
165
+ x = self.conv2(x)
166
+ return x + inputs_
167
+
168
+
169
+ class FeatureFusionBlock(BaseModule):
170
+ """FeatureFusionBlock, merge feature map from different stages.
171
+ Args:
172
+ in_channels (int): Input channels.
173
+ act_cfg (dict): The activation config for ResidualConvUnit.
174
+ norm_cfg (dict): Config dict for normalization layer.
175
+ expand (bool): Whether expand the channels in post process block.
176
+ Default: False.
177
+ align_corners (bool): align_corner setting for bilinear upsample.
178
+ Default: True.
179
+ init_cfg (dict, optional): Initialization config dict. Default: None.
180
+ """
181
+
182
+ def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None):
183
+ super(FeatureFusionBlock, self).__init__(init_cfg)
184
+
185
+ self.in_channels = in_channels
186
+ self.expand = expand
187
+ self.align_corners = align_corners
188
+
189
+ self.out_channels = in_channels
190
+ if self.expand:
191
+ self.out_channels = in_channels // 2
192
+
193
+ self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True)
194
+
195
+ self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
196
+ self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
197
+
198
+ def forward(self, *inputs):
199
+ x = inputs[0]
200
+ if len(inputs) == 2:
201
+ if x.shape != inputs[1].shape:
202
+ res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
203
+ else:
204
+ res = inputs[1]
205
+ x = x + self.res_conv_unit1(res)
206
+ x = self.res_conv_unit2(x)
207
+ x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
208
+ x = self.project(x)
209
+ return x
210
+
211
+
212
+ @HEADS.register_module()
213
+ class DPTHead(DepthBaseDecodeHead):
214
+ """Vision Transformers for Dense Prediction.
215
+ This head is implemented of `DPT <https://arxiv.org/abs/2103.13413>`_.
216
+ Args:
217
+ embed_dims (int): The embed dimension of the ViT backbone.
218
+ Default: 768.
219
+ post_process_channels (List): Out channels of post process conv
220
+ layers. Default: [96, 192, 384, 768].
221
+ readout_type (str): Type of readout operation. Default: 'ignore'.
222
+ patch_size (int): The patch size. Default: 16.
223
+ expand_channels (bool): Whether expand the channels in post process
224
+ block. Default: False.
225
+ """
226
+
227
+ def __init__(
228
+ self,
229
+ embed_dims=768,
230
+ post_process_channels=[96, 192, 384, 768],
231
+ readout_type="ignore",
232
+ patch_size=16,
233
+ expand_channels=False,
234
+ **kwargs
235
+ ):
236
+ super(DPTHead, self).__init__(**kwargs)
237
+
239
+ self.expand_channels = expand_channels
240
+ self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
241
+
242
+ self.post_process_channels = [
243
+ int(channel * math.pow(2, i)) if expand_channels else channel for i, channel in enumerate(post_process_channels)
244
+ ]
245
+ self.convs = nn.ModuleList()
246
+ for channel in self.post_process_channels:
247
+ self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False))
248
+ self.fusion_blocks = nn.ModuleList()
249
+ for _ in range(len(self.convs)):
250
+ self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg))
251
+ self.fusion_blocks[0].res_conv_unit1 = None
252
+ self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg)
253
+ self.num_fusion_blocks = len(self.fusion_blocks)
254
+ self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
255
+ self.num_post_process_channels = len(self.post_process_channels)
256
+ assert self.num_fusion_blocks == self.num_reassemble_blocks
257
+ assert self.num_reassemble_blocks == self.num_post_process_channels
258
+ self.conv_depth = HeadDepth(self.channels)
259
+
260
+ def forward(self, inputs, img_metas):
261
+ assert len(inputs) == self.num_reassemble_blocks
262
+ x = [inp for inp in inputs]
263
+ x = self.reassemble_blocks(x)
264
+ x = [self.convs[i](feature) for i, feature in enumerate(x)]
265
+ out = self.fusion_blocks[0](x[-1])
266
+ for i in range(1, len(self.fusion_blocks)):
267
+ out = self.fusion_blocks[i](out, x[-(i + 1)])
268
+ out = self.project(out)
269
+ out = self.depth_pred(out)
270
+ return out
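
For reference, the DPT head consumes one `(feature_map, cls_token)` pair per reassemble stage. A shape-only sketch of a well-formed input for the defaults above (ViT-B dims, patch size 16, so a 224x224 image yields a 14x14 token grid); the concrete numbers are assumptions for illustration:

```python
import torch

embed_dims, grid = 768, 14  # 224 // 16
features = torch.randn(1, embed_dims, grid, grid)  # tokens reshaped into a 2D grid
cls_token = torch.randn(1, embed_dims)

# One (feature_map, cls_token) pair per stage, as consumed by ReassembleBlocks.forward.
inputs = [(features.clone(), cls_token.clone()) for _ in range(4)]

# With the default resize_layers, stage resolutions become 4x, 2x, 1x and 0.5x the
# token grid (56, 28, 14, 7 here), with channels [96, 192, 384, 768].
```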
app/dinov2/eval/depth/models/decode_heads/linear_head.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from ...ops import resize
10
+ from ..builder import HEADS
11
+ from .decode_head import DepthBaseDecodeHead
12
+
13
+
14
+ @HEADS.register_module()
15
+ class BNHead(DepthBaseDecodeHead):
16
+ """Just a batchnorm."""
17
+
18
+ def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
19
+ super().__init__(**kwargs)
20
+ self.input_transform = input_transform
21
+ self.in_index = in_index
22
+ self.upsample = upsample
23
+ # self.bn = nn.SyncBatchNorm(self.in_channels)
24
+ if self.classify:
25
+ self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
26
+ else:
27
+ self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
28
+
29
+ def _transform_inputs(self, inputs):
30
+ """Transform inputs for decoder.
31
+ Args:
32
+ inputs (list[Tensor]): List of multi-level img features.
33
+ Returns:
34
+ Tensor: The transformed inputs
35
+ """
36
+
37
+ if "concat" in self.input_transform:
38
+ inputs = [inputs[i] for i in self.in_index]
39
+ if "resize" in self.input_transform:
40
+ inputs = [
41
+ resize(
42
+ input=x,
43
+ size=[s * self.upsample for s in inputs[0].shape[2:]],
44
+ mode="bilinear",
45
+ align_corners=self.align_corners,
46
+ )
47
+ for x in inputs
48
+ ]
49
+ inputs = torch.cat(inputs, dim=1)
50
+ elif self.input_transform == "multiple_select":
51
+ inputs = [inputs[i] for i in self.in_index]
52
+ else:
53
+ inputs = inputs[self.in_index]
54
+
55
+ return inputs
56
+
57
+ def _forward_feature(self, inputs, img_metas=None, **kwargs):
58
+ """Forward function for feature maps before classifying each pixel with
59
+ ``self.cls_seg`` fc.
60
+ Args:
61
+ inputs (list[Tensor]): List of multi-level img features.
62
+ Returns:
63
+ feats (Tensor): A tensor of shape (batch_size, self.channels,
64
+ H, W) which is feature map for last layer of decoder head.
65
+ """
66
+ # accept lists (for cls token)
67
+ inputs = list(inputs)
68
+ for i, x in enumerate(inputs):
69
+ if len(x) == 2:
70
+ x, cls_token = x[0], x[1]
71
+ if len(x.shape) == 2:
72
+ x = x[:, :, None, None]
73
+ cls_token = cls_token[:, :, None, None].expand_as(x)
74
+ inputs[i] = torch.cat((x, cls_token), 1)
75
+ else:
76
+ x = x[0]
77
+ if len(x.shape) == 2:
78
+ x = x[:, :, None, None]
79
+ inputs[i] = x
80
+ x = self._transform_inputs(inputs)
81
+ # feats = self.bn(x)
82
+ return x
83
+
84
+ def forward(self, inputs, img_metas=None, **kwargs):
85
+ """Forward function."""
86
+ output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
87
+ output = self.depth_pred(output)
88
+
89
+ return output
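
With the default `input_transform="resize_concat"`, `BNHead` upsamples every selected level to the first level's spatial size (scaled by `upsample`) and concatenates along channels. A shape sketch of that fusion, using plain `F.interpolate` in place of the package's `resize` wrapper:

```python
import torch
import torch.nn.functional as F

# Four ViT levels at the same 14x14 resolution, 768 channels each (illustrative sizes).
levels = [torch.randn(1, 768, 14, 14) for _ in range(4)]

upsample = 4
size = [s * upsample for s in levels[0].shape[2:]]  # (56, 56)
resized = [F.interpolate(x, size=size, mode="bilinear", align_corners=False) for x in levels]
fused = torch.cat(resized, dim=1)
assert fused.shape == (1, 4 * 768, 56, 56)  # must match the head's `channels` config
```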
app/dinov2/eval/depth/models/depther/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .base import BaseDepther
7
+ from .encoder_decoder import DepthEncoderDecoder
app/dinov2/eval/depth/models/depther/base.py ADDED
@@ -0,0 +1,194 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from abc import ABCMeta, abstractmethod
7
+ from collections import OrderedDict
8
+
9
+ import torch
10
+ import torch.distributed as dist
11
+ from mmcv.runner import BaseModule, auto_fp16
12
+
13
+
14
+ class BaseDepther(BaseModule, metaclass=ABCMeta):
15
+ """Base class for depther."""
16
+
17
+ def __init__(self, init_cfg=None):
18
+ super(BaseDepther, self).__init__(init_cfg)
19
+ self.fp16_enabled = False
20
+
21
+ @property
22
+ def with_neck(self):
23
+ """bool: whether the depther has neck"""
24
+ return hasattr(self, "neck") and self.neck is not None
25
+
26
+ @property
27
+ def with_auxiliary_head(self):
28
+ """bool: whether the depther has auxiliary head"""
29
+ return hasattr(self, "auxiliary_head") and self.auxiliary_head is not None
30
+
31
+ @property
32
+ def with_decode_head(self):
33
+ """bool: whether the depther has decode head"""
34
+ return hasattr(self, "decode_head") and self.decode_head is not None
35
+
36
+ @abstractmethod
37
+ def extract_feat(self, imgs):
38
+ """Placeholder for extract features from images."""
39
+ pass
40
+
41
+ @abstractmethod
42
+ def encode_decode(self, img, img_metas):
43
+ """Placeholder for encode images with backbone and decode into a
44
+ semantic depth map of the same size as input."""
45
+ pass
46
+
47
+ @abstractmethod
48
+ def forward_train(self, imgs, img_metas, **kwargs):
49
+ """Placeholder for Forward function for training."""
50
+ pass
51
+
52
+ @abstractmethod
53
+ def simple_test(self, img, img_meta, **kwargs):
54
+ """Placeholder for single image test."""
55
+ pass
56
+
57
+ @abstractmethod
58
+ def aug_test(self, imgs, img_metas, **kwargs):
59
+ """Placeholder for augmentation test."""
60
+ pass
61
+
62
+ def forward_test(self, imgs, img_metas, **kwargs):
63
+ """
64
+ Args:
65
+ imgs (List[Tensor]): the outer list indicates test-time
66
+ augmentations and inner Tensor should have a shape NxCxHxW,
67
+ which contains all images in the batch.
68
+ img_metas (List[List[dict]]): the outer list indicates test-time
69
+ augs (multiscale, flip, etc.) and the inner list indicates
70
+ images in a batch.
71
+ """
72
+ for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
73
+ if not isinstance(var, list):
74
+ raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
75
+ num_augs = len(imgs)
76
+ if num_augs != len(img_metas):
77
+ raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
78
+ # all images in the same aug batch must have the same ori_shape and
79
+ # pad_shape
80
+ for img_meta in img_metas:
81
+ ori_shapes = [_["ori_shape"] for _ in img_meta]
82
+ assert all(shape == ori_shapes[0] for shape in ori_shapes)
83
+ img_shapes = [_["img_shape"] for _ in img_meta]
84
+ assert all(shape == img_shapes[0] for shape in img_shapes)
85
+ pad_shapes = [_["pad_shape"] for _ in img_meta]
86
+ assert all(shape == pad_shapes[0] for shape in pad_shapes)
87
+
88
+ if num_augs == 1:
89
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
90
+ else:
91
+ return self.aug_test(imgs, img_metas, **kwargs)
92
+
93
+ @auto_fp16(apply_to=("img",))
94
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
95
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
96
+ on whether ``return_loss`` is ``True``.
97
+
98
+ Note this setting will change the expected inputs. When
99
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
100
+ and List[dict]), and when ``return_loss=False``, img and img_meta
101
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
102
+ the outer list indicating test time augmentations.
103
+ """
104
+ if return_loss:
105
+ return self.forward_train(img, img_metas, **kwargs)
106
+ else:
107
+ return self.forward_test(img, img_metas, **kwargs)
108
+
109
+ def train_step(self, data_batch, optimizer, **kwargs):
110
+ """The iteration step during training.
111
+
112
+ This method defines an iteration step during training, except for the
113
+ back propagation and optimizer updating, which are done in an optimizer
114
+ hook. Note that in some complicated cases or models, the whole process
115
+ including back propagation and optimizer updating is also defined in
116
+ this method, such as GAN.
117
+
118
+ Args:
119
+ data_batch (dict): The output of dataloader.
120
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
121
+ runner is passed to ``train_step()``. This argument is unused
122
+ and reserved.
123
+
124
+ Returns:
125
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
126
+ ``num_samples``.
127
+ ``loss`` is a tensor for back propagation, which can be a
128
+ weighted sum of multiple losses.
129
+ ``log_vars`` contains all the variables to be sent to the
130
+ logger.
131
+ ``num_samples`` indicates the batch size (when the model is
132
+ DDP, it means the batch size on each GPU), which is used for
133
+ averaging the logs.
134
+ """
135
+ losses = self(**data_batch)
136
+
137
+ # split losses and images
138
+ real_losses = {}
139
+ log_imgs = {}
140
+ for k, v in losses.items():
141
+ if "img" in k:
142
+ log_imgs[k] = v
143
+ else:
144
+ real_losses[k] = v
145
+
146
+ loss, log_vars = self._parse_losses(real_losses)
147
+
148
+ outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
149
+
150
+ return outputs
151
+
152
+ def val_step(self, data_batch, **kwargs):
153
+ """The iteration step during validation.
154
+
155
+ This method shares the same signature as :func:`train_step`, but used
156
+ during val epochs. Note that the evaluation after training epochs is
157
+ not implemented with this method, but an evaluation hook.
158
+ """
159
+ output = self(**data_batch, **kwargs)
160
+ return output
161
+
162
+ @staticmethod
163
+ def _parse_losses(losses):
164
+ """Parse the raw outputs (losses) of the network.
165
+
166
+ Args:
167
+ losses (dict): Raw output of the network, which usually contain
168
+ losses and other necessary information.
169
+
170
+ Returns:
171
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
172
+ which may be a weighted sum of all losses, log_vars contains
173
+ all the variables to be sent to the logger.
174
+ """
175
+ log_vars = OrderedDict()
176
+ for loss_name, loss_value in losses.items():
177
+ if isinstance(loss_value, torch.Tensor):
178
+ log_vars[loss_name] = loss_value.mean()
179
+ elif isinstance(loss_value, list):
180
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
181
+ else:
182
+ raise TypeError(f"{loss_name} is not a tensor or list of tensors")
183
+
184
+ loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
185
+
186
+ log_vars["loss"] = loss
187
+ for loss_name, loss_value in log_vars.items():
188
+ # reduce loss when distributed training
189
+ if dist.is_available() and dist.is_initialized():
190
+ loss_value = loss_value.data.clone()
191
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
192
+ log_vars[loss_name] = loss_value.item()
193
+
194
+ return loss, log_vars
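
`_parse_losses` sums every `log_vars` entry whose key contains "loss" into the scalar used for backprop and logs everything else as-is. A small standalone check of that contract, without the distributed reduction:

```python
import torch
from collections import OrderedDict

losses = {
    "decode.loss_depth": torch.tensor(0.7),
    "decode.aux_metric": torch.tensor(0.1),  # logged, but excluded from the total
}

log_vars = OrderedDict((name, value.mean()) for name, value in losses.items())
total = sum(value for name, value in log_vars.items() if "loss" in name)
assert torch.isclose(total, torch.tensor(0.7))
```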
app/dinov2/eval/depth/models/depther/encoder_decoder.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from ...models import builder
10
+ from ...models.builder import DEPTHER
11
+ from ...ops import resize
12
+ from .base import BaseDepther
13
+
14
+
15
+ def add_prefix(inputs, prefix):
16
+ """Add prefix for dict.
17
+
18
+ Args:
19
+ inputs (dict): The input dict with str keys.
20
+ prefix (str): The prefix to add.
21
+
22
+ Returns:
24
+ dict: The dict with keys updated with ``prefix``.
25
+ """
26
+
27
+ outputs = dict()
28
+ for name, value in inputs.items():
29
+ outputs[f"{prefix}.{name}"] = value
30
+
31
+ return outputs
32
+
33
+
34
+ @DEPTHER.register_module()
35
+ class DepthEncoderDecoder(BaseDepther):
36
+ """Encoder Decoder depther.
37
+
38
+ EncoderDecoder typically consists of backbone, (neck) and decode_head.
39
+ """
40
+
41
+ def __init__(self, backbone, decode_head, neck=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None):
42
+ super(DepthEncoderDecoder, self).__init__(init_cfg)
43
+ if pretrained is not None:
44
+ assert backbone.get("pretrained") is None, "both backbone and depther set pretrained weight"
45
+ backbone.pretrained = pretrained
46
+ self.backbone = builder.build_backbone(backbone)
47
+ self._init_decode_head(decode_head)
48
+
49
+ if neck is not None:
50
+ self.neck = builder.build_neck(neck)
51
+
52
+ self.train_cfg = train_cfg
53
+ self.test_cfg = test_cfg
54
+
55
+ assert self.with_decode_head
56
+
57
+ def _init_decode_head(self, decode_head):
58
+ """Initialize ``decode_head``"""
59
+ self.decode_head = builder.build_head(decode_head)
60
+ self.align_corners = self.decode_head.align_corners
61
+
62
+ def extract_feat(self, img):
63
+ """Extract features from images."""
64
+ x = self.backbone(img)
65
+ if self.with_neck:
66
+ x = self.neck(x)
67
+ return x
68
+
69
+ def encode_decode(self, img, img_metas, rescale=True, size=None):
70
+ """Encode images with backbone and decode into a depth estimation
71
+ map of the same size as input."""
72
+ x = self.extract_feat(img)
73
+ out = self._decode_head_forward_test(x, img_metas)
74
+ # clamp the predicted depth to the valid [min_depth, max_depth] range
75
+ out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
76
+ if rescale:
77
+ if size is None:
78
+ if img_metas is not None:
79
+ size = img_metas[0]["ori_shape"][:2]
80
+ else:
81
+ size = img.shape[2:]
82
+ out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
83
+ return out
84
+
85
+ def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
86
+ """Run forward function and calculate loss for decode head in
87
+ training."""
88
+ losses = dict()
89
+ loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, self.train_cfg, **kwargs)
90
+ losses.update(add_prefix(loss_decode, "decode"))
91
+ return losses
92
+
93
+ def _decode_head_forward_test(self, x, img_metas):
94
+ """Run forward function and calculate loss for decode head in
95
+ inference."""
96
+ depth_pred = self.decode_head.forward_test(x, img_metas, self.test_cfg)
97
+ return depth_pred
98
+
99
+ def forward_dummy(self, img):
100
+ """Dummy forward function."""
101
+ depth = self.encode_decode(img, None)
102
+
103
+ return depth
+
+     def forward_train(self, img, img_metas, depth_gt, **kwargs):
+         """Forward function for training.
+
+         Args:
+             img (Tensor): Input images.
+             img_metas (list[dict]): List of image info dicts where each dict
+                 has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                 For details on the values of these keys see
+                 `depth/datasets/pipelines/formatting.py:Collect`.
+             depth_gt (Tensor): Ground-truth depth map used to supervise the
+                 depth estimation task.
+
+         Returns:
+             dict[str, Tensor]: A dictionary of loss components.
+         """
+
+         x = self.extract_feat(img)
+
+         losses = dict()
+
+         # the last entry of x carries the information from the neck, if present
+         loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
+
+         losses.update(loss_decode)
+
+         return losses
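+         # Editor's note: because losses pass through add_prefix(..., "decode"),
+         # the returned dict uses keys like "decode.loss_depth" (the exact loss
+         # name depends on the configured head; this key is hypothetical).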
+
+     def whole_inference(self, img, img_meta, rescale, size=None):
+         """Inference with the full image."""
+         depth_pred = self.encode_decode(img, img_meta, rescale, size=size)
+
+         return depth_pred
+
+     def slide_inference(self, img, img_meta, rescale):
+         """Inference by sliding window with overlap.
+
+         If h_crop > h_img or w_crop > w_img, the small patch is decoded
+         without padding.
+         """
+
+         h_stride, w_stride = self.test_cfg.stride
+         h_crop, w_crop = self.test_cfg.crop_size
+         batch_size, _, h_img, w_img = img.size()
+         h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+         w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+         preds = img.new_zeros((batch_size, 1, h_img, w_img))
+         count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
+         for h_idx in range(h_grids):
+             for w_idx in range(w_grids):
+                 y1 = h_idx * h_stride
+                 x1 = w_idx * w_stride
+                 y2 = min(y1 + h_crop, h_img)
+                 x2 = min(x1 + w_crop, w_img)
+                 y1 = max(y2 - h_crop, 0)
+                 x1 = max(x2 - w_crop, 0)
+                 crop_img = img[:, :, y1:y2, x1:x2]
+                 depth_pred = self.encode_decode(crop_img, img_meta, rescale)
+                 preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
+
+                 count_mat[:, :, y1:y2, x1:x2] += 1
+         assert (count_mat == 0).sum() == 0
+         if torch.onnx.is_in_onnx_export():
+             # cast count_mat to a constant while exporting to ONNX
+             count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
+         preds = preds / count_mat
+         return preds
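+         # Editor's note — worked example of the grid arithmetic above, with
+         # assumed (not repo-provided) sizes: h_img=512, h_crop=256, h_stride=171
+         # gives h_grids = max(512 - 256 + 170, 0) // 171 + 1 = 3, i.e. windows
+         # starting at y1 = 0, 171, and 256 (the last start is clamped so the
+         # window ends exactly at the image border).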
+
+     def inference(self, img, img_meta, rescale, size=None):
+         """Inference with slide/whole style.
+
+         Args:
+             img (Tensor): The input image of shape (N, 3, H, W).
+             img_meta (list[dict]): List of image info dicts where each dict
+                 has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                 For details on the values of these keys see
+                 `depth/datasets/pipelines/formatting.py:Collect`.
+             rescale (bool): Whether to rescale back to the original shape.
+
+         Returns:
+             Tensor: The output depth map.
+         """
+
+         assert self.test_cfg.mode in ["slide", "whole"]
+         ori_shape = img_meta[0]["ori_shape"]
+         assert all(_["ori_shape"] == ori_shape for _ in img_meta)
+         if self.test_cfg.mode == "slide":
+             depth_pred = self.slide_inference(img, img_meta, rescale)
+         else:
+             depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
+         output = depth_pred
+         flip = img_meta[0]["flip"]
+         if flip:
+             flip_direction = img_meta[0]["flip_direction"]
+             assert flip_direction in ["horizontal", "vertical"]
+             if flip_direction == "horizontal":
+                 output = output.flip(dims=(3,))
+             elif flip_direction == "vertical":
+                 output = output.flip(dims=(2,))
+
+         return output
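+         # Editor's note — hedged test_cfg sketches (values are illustrative
+         # assumptions, not from this repo):
+         #     test_cfg = dict(mode="whole")
+         #     test_cfg = dict(mode="slide", crop_size=(480, 480), stride=(320, 320))
+         # "slide" additionally requires crop_size and stride, which
+         # slide_inference reads above.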
+
+     def simple_test(self, img, img_meta, rescale=True):
+         """Simple test with a single image."""
+         depth_pred = self.inference(img, img_meta, rescale)
+         if torch.onnx.is_in_onnx_export():
+             # our inference backend only supports 4D output
+             depth_pred = depth_pred.unsqueeze(0)
+             return depth_pred
+         depth_pred = depth_pred.cpu().numpy()
+         # unravel the batch dim
+         depth_pred = list(depth_pred)
+         return depth_pred
+
+     def aug_test(self, imgs, img_metas, rescale=True):
+         """Test with augmentations.
+
+         Only rescale=True is supported.
+         """
+         # aug_test rescales all images back to ori_shape for now
+         assert rescale
+         # to save memory, accumulate the augmented depth predictions in place
+         depth_pred = self.inference(imgs[0], img_metas[0], rescale)
+         for i in range(1, len(imgs)):
+             cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
+             depth_pred += cur_depth_pred
+         depth_pred /= len(imgs)
+         depth_pred = depth_pred.cpu().numpy()
+         # unravel the batch dim
+         depth_pred = list(depth_pred)
+         return depth_pred
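+         # Editor's note — hedged end-to-end sketch (tensor sizes and meta
+         # values are illustrative assumptions, not from this repo):
+         #     img = torch.randn(1, 3, 480, 640)
+         #     img_meta = [dict(ori_shape=(480, 640, 3), img_shape=(480, 640, 3),
+         #                      pad_shape=(480, 640, 3), scale_factor=1.0, flip=False)]
+         #     pred = model.simple_test(img, img_meta)  # list of (1, 480, 640) arrays
+         # aug_test takes lists of such tensors/metas and averages the
+         # per-augmentation predictions.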