Add files using upload-large-folder tool
Browse files- torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation/__init__.py +4 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py +92 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/__init__.py +11 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py +442 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py +32 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py +217 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py +544 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/utils/point_sample.py +86 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/__init__.cpython-310.pyc +0 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/backbones.cpython-310.pyc +0 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/classifiers.cpython-310.pyc +0 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/dinotxt.cpython-310.pyc +0 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/utils.cpython-310.pyc +0 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/__init__.py +7 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/__pycache__/ops.cpython-310.pyc +0 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/decode_heads.py +747 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/encoder_decoder.py +351 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/ops.py +28 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/hub/utils.py +39 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/hub/xray_dino/__pycache__/backbones.cpython-310.pyc +0 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/hub/xray_dino/backbones.py +28 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/layers/__init__.py +12 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/dino_head.cpython-310.pyc +0 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc +0 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc +0 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/layers/attention.py +99 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/layers/patch_embed.py +88 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/logging/__init__.py +102 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/logging/helpers.py +194 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/loss/__init__.py +8 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/loss/ibot_patch_loss.py +151 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py +428 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/run/__init__.py +4 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/run/eval/cell_dino/knn.py +59 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/run/eval/cell_dino/linear.py +59 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/run/submit.py +122 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/run/train/train.py +59 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/thirdparty/CLIP/LICENSE +21 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/thirdparty/CLIP/clip/simple_tokenizer.py +135 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/utils/__init__.py +4 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/utils/checkpoint.py +63 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/utils/cluster.py +95 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/utils/config.py +72 -0
- torch_hub/facebookresearch_dinov2_main/dinov2/utils/param_groups.py +103 -0
- torch_hub/facebookresearch_dinov2_main/docs/README_CHANNEL_ADAPTIVE_DINO.md +156 -0
- torch_hub/facebookresearch_dinov2_main/notebooks/cell_dino/inference.ipynb +179 -0
- torch_hub/facebookresearch_dinov2_main/notebooks/depth_estimation.ipynb +0 -0
- torch_hub/facebookresearch_dinov2_main/notebooks/semantic_segmentation.ipynb +0 -0
- torch_hub/facebookresearch_dinov2_main/scripts/cell_dino/launcher_knn_eval_on_chammi.sh +34 -0
- torch_hub/trusted_list +0 -0
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from abc import ABCMeta, abstractmethod
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from .sampling_result import SamplingResult
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BaseSampler(metaclass=ABCMeta):
    """Abstract base class for bbox samplers.

    Subclasses provide ``_sample_pos`` and ``_sample_neg``; :meth:`sample`
    combines their outputs into a :obj:`SamplingResult`.
    """

    def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs):
        """Store the sampling configuration.

        Args:
            num: Total sampling budget; ``sample`` derives the positive and
                negative targets from it.
            pos_fraction: Fraction of ``num`` targeted for positive samples.
            neg_pos_ub: When >= 0, caps the number of negatives at
                ``neg_pos_ub * max(1, num_sampled_pos)``.
            add_gt_as_proposals: Whether ground truth boxes are prepended to
                the candidate boxes before sampling.
            **kwargs: Accepted and ignored by this base class.
        """
        self.num, self.pos_fraction = num, pos_fraction
        self.neg_pos_ub = neg_pos_ub
        self.add_gt_as_proposals = add_gt_as_proposals
        # The sampler handles both positives and negatives itself.
        self.pos_sampler = self.neg_sampler = self

    @abstractmethod
    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Sample positive samples."""

    @abstractmethod
    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Sample negative samples."""

    def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs):
        """Sample positive and negative bboxes.

        A simple implementation of bbox sampling given candidate boxes, their
        assignment results and the ground truth bboxes.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
                Modified in place (``add_gt_``) when ground truth boxes are
                added as proposals.
            bboxes (Tensor): Boxes to be sampled from. Only the first four
                columns are kept.
            gt_bboxes (Tensor): Ground truth bboxes.
            gt_labels (Tensor, optional): Class labels of ground truth bboxes.

        Returns:
            :obj:`SamplingResult`: Sampling result.

        Raises:
            ValueError: If ground truth boxes are to be added as proposals
                but ``gt_labels`` is None.

        Example:
            >>> from mmdet.core.bbox import RandomSampler
            >>> from mmdet.core.bbox import AssignResult
            >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
            >>> rng = ensure_rng(None)
            >>> assign_result = AssignResult.random(rng=rng)
            >>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
            >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
            >>> gt_labels = None
            >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
            ...                      add_gt_as_proposals=False)
            >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
        """
        # Promote a single box to a one-row batch.
        if bboxes.ndim < 2:
            bboxes = bboxes[None, :]
        bboxes = bboxes[:, :4]

        # 1 marks a candidate that is a ground truth box, 0 a regular one.
        gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8)
        if self.add_gt_as_proposals and len(gt_bboxes) > 0:
            if gt_labels is None:
                raise ValueError("gt_labels must be given when add_gt_as_proposals is True")
            bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
            assign_result.add_gt_(gt_labels)
            gt_flags = torch.cat([bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8), gt_flags])

        num_expected_pos = int(self.num * self.pos_fraction)
        # We found that sampled indices have duplicated items occasionally.
        # (may be a bug of PyTorch)
        pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=bboxes, **kwargs).unique()
        num_sampled_pos = pos_inds.numel()

        # Negatives fill what is left of the budget, optionally capped
        # relative to the number of positives actually drawn.
        num_expected_neg = self.num - num_sampled_pos
        if self.neg_pos_ub >= 0:
            neg_upper_bound = int(self.neg_pos_ub * max(1, num_sampled_pos))
            num_expected_neg = min(num_expected_neg, neg_upper_bound)
        neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=bboxes, **kwargs).unique()

        return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags)
|
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .backbones import * # noqa: F403
|
| 7 |
+
from .builder import MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, build_match_cost
|
| 8 |
+
from .decode_heads import * # noqa: F403
|
| 9 |
+
from .losses import * # noqa: F403
|
| 10 |
+
from .plugins import * # noqa: F403
|
| 11 |
+
from .segmentors import * # noqa: F403
|
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from functools import partial
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.utils.checkpoint as cp
|
| 11 |
+
|
| 12 |
+
from ...ops.modules import MSDeformAttn
|
| 13 |
+
from .drop_path import DropPath
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_reference_points(spatial_shapes, device):
    """Build normalized cell-centre reference points for a set of grids.

    For every ``(H, W)`` in ``spatial_shapes`` the centres of an ``H x W``
    grid are generated in row-major order and divided by the grid size, so
    coordinates lie strictly between 0 and 1. Levels are concatenated in the
    order given.

    Args:
        spatial_shapes: Iterable of ``(H, W)`` pairs, one per level.
        device: Device on which the points are created.

    Returns:
        Tensor of shape ``(1, sum(H * W), 1, 2)`` and dtype float32 whose
        last dimension holds ``(x, y)``.
    """
    reference_points_list = []
    for H_, W_ in spatial_shapes:
        # indexing="ij" is the layout the historical default produced; naming
        # it explicitly avoids the deprecation warning and keeps results
        # stable if the library default changes.
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
            torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
            indexing="ij",
        )
        ref_y = ref_y.reshape(-1)[None] / H_
        ref_x = ref_x.reshape(-1)[None] / W_
        reference_points_list.append(torch.stack((ref_x, ref_y), -1))
    reference_points = torch.cat(reference_points_list, 1)
    # Insert a singleton axis between the point axis and the coordinate axis.
    return reference_points[:, :, None]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def deform_inputs(x, patch_size):
    """Prepare the two argument triples used by the deformable attention calls.

    Args:
        x: 4-D input whose last two dimensions are height and width.
        patch_size: Patch size defining the token grid
            ``(h // patch_size, w // patch_size)``.

    Returns:
        Two lists ``[reference_points, spatial_shapes, level_start_index]``:
        the first pairs reference points on the token grid with the
        stride-8/16/32 grids as spatial shapes; the second pairs reference
        points on the stride-8/16/32 grids with the token grid as spatial
        shape.
    """
    _, _, h, w = x.shape
    pyramid_grids = [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)]
    token_grid = [(h // patch_size, w // patch_size)]

    def _triple(point_grids, shape_grids):
        spatial_shapes = torch.as_tensor(shape_grids, dtype=torch.long, device=x.device)
        # Offset of each level inside the flattened, concatenated sequence.
        cells_per_level = spatial_shapes.prod(1)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), cells_per_level.cumsum(0)[:-1]))
        reference_points = get_reference_points(point_grids, x.device)
        return [reference_points, spatial_shapes, level_start_index]

    deform_inputs1 = _triple(token_grid, pyramid_grids)
    deform_inputs2 = _triple(pyramid_grids, token_grid)
    return deform_inputs1, deform_inputs2
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ConvFFN(nn.Module):
    """Feed-forward block with a depthwise convolution between two linear layers."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        """
        Args:
            in_features: Input channel count.
            hidden_features: Hidden width; falls back to ``in_features`` when falsy.
            out_features: Output width; falls back to ``in_features`` when falsy.
            act_layer: Callable returning the activation module.
            drop: Dropout probability, applied after the activation and after ``fc2``.
        """
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        """Apply fc1 -> depthwise conv -> activation -> dropout -> fc2 -> dropout.

        ``H`` and ``W`` are passed through unchanged to the depthwise conv.
        """
        hidden = self.dwconv(self.fc1(x), H, W)
        hidden = self.drop(self.act(hidden))
        return self.drop(self.fc2(hidden))
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class DWConv(nn.Module):
|
| 71 |
+
def __init__(self, dim=768):
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
|
| 74 |
+
|
| 75 |
+
def forward(self, x, H, W):
|
| 76 |
+
B, N, C = x.shape
|
| 77 |
+
n = N // 21
|
| 78 |
+
x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous()
|
| 79 |
+
x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous()
|
| 80 |
+
x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous()
|
| 81 |
+
x1 = self.dwconv(x1).flatten(2).transpose(1, 2)
|
| 82 |
+
x2 = self.dwconv(x2).flatten(2).transpose(1, 2)
|
| 83 |
+
x3 = self.dwconv(x3).flatten(2).transpose(1, 2)
|
| 84 |
+
x = torch.cat([x1, x2, x3], dim=1)
|
| 85 |
+
return x
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class Extractor(nn.Module):
    """Residual deformable-attention update of ``query`` from ``feat``, with an optional ConvFFN."""

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        n_levels=1,
        deform_ratio=1.0,
        with_cffn=True,
        cffn_ratio=0.25,
        drop=0.0,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        with_cp=False,
    ):
        """
        Args:
            dim: Channel width of query and feature tokens.
            num_heads, n_points, n_levels, deform_ratio: Forwarded to
                ``MSDeformAttn`` (as ``n_heads``, ``n_points``, ``n_levels``, ``ratio``).
            with_cffn: Whether to add the ConvFFN residual branch.
            cffn_ratio: ConvFFN hidden width as a fraction of ``dim``.
            drop: Dropout probability inside the ConvFFN.
            drop_path: Stochastic-depth probability on the ConvFFN branch.
            norm_layer: Callable building a normalization layer for ``dim`` channels.
            with_cp: Use activation checkpointing when the query requires grad.
        """
        super().__init__()
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(
            d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
        )
        self.with_cffn = with_cffn
        self.with_cp = with_cp
        if with_cffn:
            self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * cffn_ratio), drop=drop)
            self.ffn_norm = norm_layer(dim)
            self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W):
        """Return ``query`` plus the attention output and, if enabled, the ConvFFN output.

        ``H`` and ``W`` are only consumed by the ConvFFN branch.
        """

        def _update(q, f):
            attended = self.attn(
                self.query_norm(q), reference_points, self.feat_norm(f), spatial_shapes, level_start_index, None
            )
            q = q + attended
            if self.with_cffn:
                q = q + self.drop_path(self.ffn(self.ffn_norm(q), H, W))
            return q

        # Checkpointing only pays off when gradients flow through the query.
        if self.with_cp and query.requires_grad:
            return cp.checkpoint(_update, query, feat)
        return _update(query, feat)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class Injector(nn.Module):
    """Adds a per-channel scaled deformable-attention output of ``feat`` onto ``query``."""

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        n_levels=1,
        deform_ratio=1.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_values=0.0,
        with_cp=False,
    ):
        """
        Args:
            dim: Channel width of query and feature tokens.
            num_heads, n_points, n_levels, deform_ratio: Forwarded to
                ``MSDeformAttn`` (as ``n_heads``, ``n_points``, ``n_levels``, ``ratio``).
            norm_layer: Callable building a normalization layer for ``dim`` channels.
            init_values: Initial value of every entry of the learnable scale ``gamma``.
            with_cp: Use activation checkpointing when the query requires grad.
        """
        super().__init__()
        self.with_cp = with_cp
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(
            d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
        )
        # Learnable per-channel gate on the injected signal.
        self.gamma = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, query, reference_points, feat, spatial_shapes, level_start_index):
        """Return ``query + gamma * attn(norm(query), norm(feat))``."""

        def _inject(q, f):
            attended = self.attn(
                self.query_norm(q), reference_points, self.feat_norm(f), spatial_shapes, level_start_index, None
            )
            return q + self.gamma * attended

        # Checkpointing only pays off when gradients flow through the query.
        if self.with_cp and query.requires_grad:
            return cp.checkpoint(_inject, query, feat)
        return _inject(query, feat)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class InteractionBlock(nn.Module):
    """Injector -> caller-supplied blocks -> extractor(s) interaction stage."""

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        """
        Args:
            dim: Channel width shared by injector and extractors.
            num_heads, n_points, deform_ratio: Deformable-attention settings.
            norm_layer: Callable building a normalization layer.
            drop, drop_path, with_cffn, cffn_ratio: Extractor ConvFFN settings.
            init_values: Initial value of the injector scale.
            extra_extractor: When true, two additional extractors run after the main one.
            with_cp: Enable activation checkpointing in the sub-modules.
        """
        super().__init__()

        self.injector = Injector(
            dim=dim,
            n_levels=3,
            num_heads=num_heads,
            init_values=init_values,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )

        # Every extractor in this block shares the same configuration.
        extractor_cfg = dict(
            dim=dim,
            n_levels=1,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cffn=with_cffn,
            cffn_ratio=cffn_ratio,
            drop=drop,
            drop_path=drop_path,
            with_cp=with_cp,
        )
        self.extractor = Extractor(**extractor_cfg)
        self.extra_extractors = None
        if extra_extractor:
            self.extra_extractors = nn.Sequential(*[Extractor(**extractor_cfg) for _ in range(2)])

    def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        """Inject ``c`` into ``x``, run ``blocks`` on ``x``, then refresh ``c`` from ``x``.

        Args:
            x: Tokens that receive the injection and pass through ``blocks``.
            c: Tokens used as injector features and updated by the extractors.
            blocks: Iterable of callables invoked as ``blk(x, H_toks, W_toks)``.
            deform_inputs1: ``[reference_points, spatial_shapes, level_start_index]`` for the injector.
            deform_inputs2: The same triple for the extractors.
            H_c, W_c: Spatial size forwarded to the extractors.
            H_toks, W_toks: Spatial size forwarded to every block.

        Returns:
            Tuple ``(x, c)`` of updated tensors.
        """
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )
        for blk in blocks:
            x = blk(x, H_toks, W_toks)

        extractors = [self.extractor]
        if self.extra_extractors is not None:
            extractors.extend(self.extra_extractors)
        for extractor in extractors:
            c = extractor(
                query=c,
                reference_points=deform_inputs2[0],
                feat=x,
                spatial_shapes=deform_inputs2[1],
                level_start_index=deform_inputs2[2],
                H=H_c,
                W=W_c,
            )
        return x, c
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class InteractionBlockWithCls(nn.Module):
    """Interaction stage that carries a separate leading token through the blocks."""

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        """
        Args:
            dim: Channel width shared by injector and extractors.
            num_heads, n_points, deform_ratio: Deformable-attention settings.
            norm_layer: Callable building a normalization layer.
            drop, drop_path, with_cffn, cffn_ratio: Extractor ConvFFN settings.
            init_values: Initial value of the injector scale.
            extra_extractor: When true, two additional extractors run after the main one.
            with_cp: Enable activation checkpointing in the sub-modules.
        """
        super().__init__()

        self.injector = Injector(
            dim=dim,
            n_levels=3,
            num_heads=num_heads,
            init_values=init_values,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )

        # Every extractor in this block shares the same configuration.
        extractor_cfg = dict(
            dim=dim,
            n_levels=1,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cffn=with_cffn,
            cffn_ratio=cffn_ratio,
            drop=drop,
            drop_path=drop_path,
            with_cp=with_cp,
        )
        self.extractor = Extractor(**extractor_cfg)
        self.extra_extractors = None
        if extra_extractor:
            self.extra_extractors = nn.Sequential(*[Extractor(**extractor_cfg) for _ in range(2)])

    def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        """Inject ``c`` into ``x``, run ``blocks`` on ``[cls, x]``, then refresh ``c`` from ``x``.

        Args:
            x: Tokens that receive the injection.
            c: Tokens used as injector features and updated by the extractors.
            cls: Tensor concatenated in front of ``x`` along dim 1 for the
                blocks; afterwards the first position along dim 1 is split
                off again and returned.
            blocks: Iterable of callables invoked as ``blk(x, H_toks, W_toks)``.
            deform_inputs1: ``[reference_points, spatial_shapes, level_start_index]`` for the injector.
            deform_inputs2: The same triple for the extractors.
            H_c, W_c: Spatial size forwarded to the extractors.
            H_toks, W_toks: Spatial size forwarded to every block.

        Returns:
            Tuple ``(x, c, cls)`` of updated tensors.
        """
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )

        # The blocks see the extra token; injector and extractors do not.
        x = torch.cat((cls, x), dim=1)
        for blk in blocks:
            x = blk(x, H_toks, W_toks)
        cls, x = x[:, :1], x[:, 1:]

        extractors = [self.extractor]
        if self.extra_extractors is not None:
            extractors.extend(self.extra_extractors)
        for extractor in extractors:
            c = extractor(
                query=c,
                reference_points=deform_inputs2[0],
                feat=x,
                spatial_shapes=deform_inputs2[1],
                level_start_index=deform_inputs2[2],
                H=H_c,
                W=W_c,
            )
        return x, c, cls
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
class SpatialPriorModule(nn.Module):
    """Convolutional stem producing four feature levels projected to ``embed_dim`` channels."""

    def __init__(self, inplanes=64, embed_dim=384, with_cp=False):
        """
        Args:
            inplanes: Channel width of the stem; later stages use 2x and 4x this width.
            embed_dim: Output channels of the 1x1 projections ``fc1``..``fc4``.
            with_cp: Use activation checkpointing when the input requires grad.
        """
        super().__init__()
        self.with_cp = with_cp

        def conv_bn_relu(in_ch, out_ch, stride):
            # 3x3 conv (no bias) -> SyncBatchNorm -> in-place ReLU.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.SyncBatchNorm(out_ch),
                nn.ReLU(inplace=True),
            ]

        self.stem = nn.Sequential(
            *conv_bn_relu(3, inplanes, 2),
            *conv_bn_relu(inplanes, inplanes, 1),
            *conv_bn_relu(inplanes, inplanes, 1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.conv2 = nn.Sequential(*conv_bn_relu(inplanes, 2 * inplanes, 2))
        self.conv3 = nn.Sequential(*conv_bn_relu(2 * inplanes, 4 * inplanes, 2))
        self.conv4 = nn.Sequential(*conv_bn_relu(4 * inplanes, 4 * inplanes, 2))

        # 1x1 projections bringing every level to embed_dim channels.
        project = partial(nn.Conv2d, out_channels=embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc1 = project(inplanes)
        self.fc2 = project(2 * inplanes)
        self.fc3 = project(4 * inplanes)
        self.fc4 = project(4 * inplanes)

    def forward(self, x):
        """Return ``(c1, c2, c3, c4)``.

        ``c1`` stays a 4-D feature map; ``c2``..``c4`` are flattened over
        their spatial dimensions and transposed to ``(batch, tokens, channels)``.
        """

        def _levels(image):
            s4 = self.stem(image)
            s8 = self.conv2(s4)
            s16 = self.conv3(s8)
            s32 = self.conv4(s16)

            c1 = self.fc1(s4)
            c2 = self.fc2(s8)
            c3 = self.fc3(s16)
            c4 = self.fc4(s32)

            bs, dim, _, _ = c1.shape
            # c1 (4s) is intentionally left unflattened.
            c2, c3, c4 = (level.view(bs, dim, -1).transpose(1, 2) for level in (c2, c3, c4))  # 8s, 16s, 32s
            return c1, c2, c3, c4

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_levels, x)
        return _levels(x)
|
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 9 |
+
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of ``x`` and rescale the survivors.

    Returns ``x`` itself when ``training`` is false or ``drop_prob`` is 0.
    Otherwise one Bernoulli draw with keep probability ``1 - drop_prob`` is
    made per entry of the first dimension and broadcast over the rest; kept
    samples are divided by the keep probability.
    """
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One mask entry per sample, broadcastable over all trailing dimensions.
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        # Probability with which a whole sample is zeroed during training.
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` using this module's probability and training flag."""
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
|
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from mmseg.models.builder import BACKBONES
|
| 12 |
+
from torch.nn.init import normal_
|
| 13 |
+
|
| 14 |
+
from ...ops.modules import MSDeformAttn
|
| 15 |
+
from .adapter_modules import InteractionBlock, InteractionBlockWithCls, SpatialPriorModule, deform_inputs
|
| 16 |
+
from .vit import TIMMVisionTransformer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@BACKBONES.register_module()
class ViTAdapter(TIMMVisionTransformer):
    """ViT backbone combined with a spatial prior module and interaction blocks.

    ``forward`` returns a list of four feature maps, each normalised by its own
    ``nn.SyncBatchNorm``.
    """

    def __init__(
        self,
        pretrain_size=224,
        num_heads=12,
        conv_inplane=64,
        n_points=4,
        deform_num_heads=6,
        init_values=0.0,
        interaction_indexes=None,
        with_cffn=True,
        cffn_ratio=0.25,
        deform_ratio=1.0,
        add_vit_feature=True,
        pretrained=None,
        use_extra_extractor=True,
        freeze_vit=False,
        use_cls=True,
        with_cp=False,
        *args,
        **kwargs
    ):
        """Build the base transformer, then the adapter modules on top of it.

        Args:
            pretrain_size: Side length stored (as a square) in ``self.pretrain_size``
                and used by ``_get_pos_embed`` to recover the position-embedding grid.
            num_heads: Forwarded to the base-class constructor.
            conv_inplane: ``inplanes`` of the ``SpatialPriorModule``.
            n_points: ``n_points`` of every interaction block.
            deform_num_heads: ``num_heads`` of every interaction block.
            init_values: ``init_values`` of every interaction block.
            interaction_indexes: One index sequence per interaction block; its
                first and last entries select the inclusive range of
                ``self.blocks`` handed to that block in ``forward``. Required.
            with_cffn: Forwarded to every interaction block.
            cffn_ratio: Forwarded to every interaction block.
            deform_ratio: Forwarded to every interaction block.
            add_vit_feature: When true, ``forward`` adds the resized per-interaction
                ViT outputs to the adapter features (this path unpacks exactly
                four outputs).
            pretrained: Forwarded to the base-class constructor.
            use_extra_extractor: When true, only the last interaction block gets
                ``extra_extractor=True``.
            freeze_vit: When true, every parameter that exists right after the
                base-class constructor returns gets ``requires_grad = False``.
            use_cls: Selects ``InteractionBlockWithCls`` (true) or
                ``InteractionBlock`` (false); when false ``self.cls_token`` is
                set to ``None``.
            with_cp: Forwarded to the base class and to every interaction block.
            *args: Extra positional arguments for the base-class constructor.
            **kwargs: Extra keyword arguments for the base-class constructor.

        Raises:
            ValueError: If ``interaction_indexes`` is ``None``.
        """
        # Fail early with a clear message instead of an opaque ``len(None)``
        # TypeError after the whole backbone has already been constructed.
        if interaction_indexes is None:
            raise ValueError(
                "interaction_indexes must be provided: one sequence of block indexes per interaction block"
            )

        super().__init__(*args, num_heads=num_heads, pretrained=pretrained, with_cp=with_cp, **kwargs)
        if freeze_vit:
            # The adapter modules are created further down, so this loop only
            # reaches the parameters that exist at this point.
            for param in self.parameters():
                param.requires_grad = False

        self.use_cls = use_cls
        if not self.use_cls:
            self.cls_token = None
        self.num_block = len(self.blocks)
        self.pretrain_size = (pretrain_size, pretrain_size)
        self.interaction_indexes = interaction_indexes
        self.add_vit_feature = add_vit_feature
        embed_dim = self.embed_dim

        block_fn = InteractionBlockWithCls if use_cls else InteractionBlock
        num_interactions = len(interaction_indexes)

        self.level_embed = nn.Parameter(torch.zeros(3, embed_dim))
        self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False)
        self.interactions = nn.Sequential(
            *[
                block_fn(
                    dim=embed_dim,
                    num_heads=deform_num_heads,
                    n_points=n_points,
                    init_values=init_values,
                    drop_path=self.drop_path_rate,
                    norm_layer=self.norm_layer,
                    with_cffn=with_cffn,
                    cffn_ratio=cffn_ratio,
                    deform_ratio=deform_ratio,
                    # Only the final interaction block may carry the extra extractor.
                    extra_extractor=(i == num_interactions - 1 and use_extra_extractor),
                    with_cp=with_cp,
                )
                for i in range(num_interactions)
            ]
        )
        self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)
        self.norm1 = nn.SyncBatchNorm(embed_dim)
        self.norm2 = nn.SyncBatchNorm(embed_dim)
        self.norm3 = nn.SyncBatchNorm(embed_dim)
        self.norm4 = nn.SyncBatchNorm(embed_dim)

        self.up.apply(self._init_weights)
        self.spm.apply(self._init_weights)
        self.interactions.apply(self._init_weights)
        # Applied to the whole model last so every MSDeformAttn ends up with its
        # own reset, regardless of what ``_init_weights`` did to its sub-layers.
        self.apply(self._init_deform_weights)
        normal_(self.level_embed)

    def _init_weights(self, m):
        """Initialise one sub-module; intended for ``nn.Module.apply``.

        Linear layers get a truncated normal (std 0.02) weight and zero bias,
        LayerNorm/BatchNorm2d get unit weight and zero bias, and (transposed)
        2D convolutions get a zero-mean normal weight with variance
        ``2 / fan_out`` and zero bias. Other module types are left untouched.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def _get_pos_embed(self, pos_embed, H, W):
        """Resample position embeddings from the pretraining grid to ``(H, W)``.

        ``pos_embed`` is reshaped to the grid implied by ``self.pretrain_size``
        and ``self.patch_size``, resized with bicubic interpolation, and
        returned flattened as ``(1, H * W, C)``.
        """
        grid_h = self.pretrain_size[0] // self.patch_size
        grid_w = self.pretrain_size[1] // self.patch_size
        pos_embed = pos_embed.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2)
        pos_embed = F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
        return pos_embed.reshape(1, -1, H * W).permute(0, 2, 1)

    def _init_deform_weights(self, m):
        """Reset the parameters of ``m`` when it is an ``MSDeformAttn``."""
        if isinstance(m, MSDeformAttn):
            m._reset_parameters()

    def _add_level_embed(self, c2, c3, c4):
        """Add rows 0, 1 and 2 of ``self.level_embed`` to ``c2``, ``c3`` and ``c4``."""
        c2 = c2 + self.level_embed[0]
        c3 = c3 + self.level_embed[1]
        c4 = c4 + self.level_embed[2]
        return c2, c3, c4

    def forward(self, x):
        """Run the adapter and return ``[f1, f2, f3, f4]``.

        ``x`` is indexed as a 4D tensor whose last two axes are the spatial
        size. With ``H_c = x.shape[2] // 16`` (and likewise ``W_c``), the
        returned maps have spatial sizes ``4 * H_c``, ``2 * H_c``, ``H_c`` and
        ``H_c // 2`` respectively.
        """
        deform_inputs1, deform_inputs2 = deform_inputs(x, self.patch_size)

        # SPM forward
        c1, c2, c3, c4 = self.spm(x)
        c2, c3, c4 = self._add_level_embed(c2, c3, c4)
        # Remember how many tokens each level contributes so the concatenation
        # can be split back into levels after the interaction blocks.
        n2, n3 = c2.size(1), c3.size(1)
        c = torch.cat([c2, c3, c4], dim=1)

        # Patch Embedding forward
        H_c, W_c = x.shape[2] // 16, x.shape[3] // 16
        x, H_toks, W_toks = self.patch_embed(x)
        bs, _, dim = x.shape
        pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H_toks, W_toks)
        if self.use_cls:
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_token, x), dim=1)
            pos_embed = torch.cat((self.pos_embed[:, :1], pos_embed), dim=1)
        x = self.pos_drop(x + pos_embed)
        # For CLIP
        x = self.norm_pre(x)

        # Interaction
        if self.use_cls:
            # Separate the class token from the patch tokens again.
            cls, x = x[:, :1], x[:, 1:]
        outs = []
        for layer, indexes in zip(self.interactions, self.interaction_indexes):
            # Inclusive range of transformer blocks owned by this interaction.
            blocks = self.blocks[indexes[0] : indexes[-1] + 1]
            if self.use_cls:
                x, c, cls = layer(
                    x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks
                )
            else:
                x, c = layer(x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks)
            outs.append(x.transpose(1, 2).view(bs, dim, H_toks, W_toks).contiguous())

        # Split & Reshape
        c2 = c[:, :n2, :]
        c3 = c[:, n2 : n2 + n3, :]
        c4 = c[:, n2 + n3 :, :]

        c2 = c2.transpose(1, 2).view(bs, dim, H_c * 2, W_c * 2).contiguous()
        c3 = c3.transpose(1, 2).view(bs, dim, H_c, W_c).contiguous()
        c4 = c4.transpose(1, 2).view(bs, dim, H_c // 2, W_c // 2).contiguous()
        c1 = self.up(c2) + c1

        if self.add_vit_feature:
            # This unpacking requires exactly four interaction blocks.
            x1, x2, x3, x4 = outs

            x1 = F.interpolate(x1, size=(4 * H_c, 4 * W_c), mode="bilinear", align_corners=False)
            x2 = F.interpolate(x2, size=(2 * H_c, 2 * W_c), mode="bilinear", align_corners=False)
            x3 = F.interpolate(x3, size=(1 * H_c, 1 * W_c), mode="bilinear", align_corners=False)
            x4 = F.interpolate(x4, size=(H_c // 2, W_c // 2), mode="bilinear", align_corners=False)
            c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4

        # Final Norm
        f1 = self.norm1(c1)
        f2 = self.norm2(c2)
        f3 = self.norm3(c3)
        f4 = self.norm4(c4)
        return [f1, f2, f3, f4]
|
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
|
| 12 |
+
from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence
|
| 13 |
+
from mmcv.ops import point_sample
|
| 14 |
+
from mmcv.runner import ModuleList, force_fp32
|
| 15 |
+
from mmseg.models.builder import HEADS, build_loss
|
| 16 |
+
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
| 17 |
+
|
| 18 |
+
from ...core import build_sampler, multi_apply, reduce_mean
|
| 19 |
+
from ..builder import build_assigner
|
| 20 |
+
from ..utils import get_uncertain_point_coords_with_randomness
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@HEADS.register_module()
|
| 24 |
+
class Mask2FormerHead(BaseDecodeHead):
|
| 25 |
+
"""Implements the Mask2Former head.
|
| 26 |
+
|
| 27 |
+
See `Masked-attention Mask Transformer for Universal Image
|
| 28 |
+
Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
in_channels (list[int]): Number of channels in the input feature map.
|
| 32 |
+
feat_channels (int): Number of channels for features.
|
| 33 |
+
out_channels (int): Number of channels for output.
|
| 34 |
+
num_things_classes (int): Number of things.
|
| 35 |
+
num_stuff_classes (int): Number of stuff.
|
| 36 |
+
num_queries (int): Number of query in Transformer decoder.
|
| 37 |
+
pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
|
| 38 |
+
decoder. Defaults to None.
|
| 39 |
+
enforce_decoder_input_project (bool, optional): Whether to add
|
| 40 |
+
a layer to change the embed_dim of tranformer encoder in
|
| 41 |
+
pixel decoder to the embed_dim of transformer decoder.
|
| 42 |
+
Defaults to False.
|
| 43 |
+
transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
|
| 44 |
+
transformer decoder. Defaults to None.
|
| 45 |
+
positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
|
| 46 |
+
transformer decoder position encoding. Defaults to None.
|
| 47 |
+
loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
|
| 48 |
+
loss. Defaults to None.
|
| 49 |
+
loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
|
| 50 |
+
Defaults to None.
|
| 51 |
+
loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
|
| 52 |
+
Defaults to None.
|
| 53 |
+
train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of
|
| 54 |
+
Mask2Former head.
|
| 55 |
+
test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of
|
| 56 |
+
Mask2Former head.
|
| 57 |
+
init_cfg (dict or list[dict], optional): Initialization config dict.
|
| 58 |
+
Defaults to None.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def __init__(
    self,
    in_channels,
    feat_channels,
    out_channels,
    num_things_classes=80,
    num_stuff_classes=53,
    num_queries=100,
    num_transformer_feat_level=3,
    pixel_decoder=None,
    enforce_decoder_input_project=False,
    transformer_decoder=None,
    positional_encoding=None,
    loss_cls=None,
    loss_mask=None,
    loss_dice=None,
    train_cfg=None,
    test_cfg=None,
    init_cfg=None,
    **kwargs,
):
    """Build the pixel decoder, transformer decoder, prediction heads and losses.

    The arguments are described in the class docstring. Although the config
    arguments default to ``None``, ``pixel_decoder``, ``transformer_decoder``
    and ``loss_cls`` are dereferenced unconditionally below.
    """
    total_classes = num_things_classes + num_stuff_classes
    super().__init__(
        in_channels=in_channels,
        channels=feat_channels,
        num_classes=total_classes,
        init_cfg=init_cfg,
        input_transform="multiple_select",
        **kwargs,
    )
    self.num_things_classes = num_things_classes
    self.num_stuff_classes = num_stuff_classes
    self.num_classes = total_classes
    self.num_queries = num_queries
    self.num_transformer_feat_level = num_transformer_feat_level
    self.num_heads = transformer_decoder.transformerlayers.attn_cfgs.num_heads
    self.num_transformer_decoder_layers = transformer_decoder.num_layers
    assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level

    # Work on a copy so the caller's pixel-decoder config is not modified.
    pixel_decoder_cfg = copy.deepcopy(pixel_decoder)
    pixel_decoder_cfg.update(in_channels=in_channels, feat_channels=feat_channels, out_channels=out_channels)
    self.pixel_decoder = build_plugin_layer(pixel_decoder_cfg)[1]
    self.transformer_decoder = build_transformer_layer_sequence(transformer_decoder)
    self.decoder_embed_dims = self.transformer_decoder.embed_dims

    # One input projection per feature level (from low resolution to high
    # resolution); a 1x1 conv when the widths differ or when it is enforced,
    # otherwise a pass-through.
    needs_projection = enforce_decoder_input_project or self.decoder_embed_dims != feat_channels
    self.decoder_input_projs = ModuleList()
    for _ in range(num_transformer_feat_level):
        if needs_projection:
            proj = Conv2d(feat_channels, self.decoder_embed_dims, kernel_size=1)
        else:
            proj = nn.Identity()
        self.decoder_input_projs.append(proj)
    self.decoder_positional_encoding = build_positional_encoding(positional_encoding)
    self.query_embed = nn.Embedding(self.num_queries, feat_channels)
    self.query_feat = nn.Embedding(self.num_queries, feat_channels)
    # from low resolution to high resolution
    self.level_embed = nn.Embedding(self.num_transformer_feat_level, feat_channels)

    # The classifier has one extra output on top of ``num_classes``.
    self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
    self.mask_embed = nn.Sequential(
        nn.Linear(feat_channels, feat_channels),
        nn.ReLU(inplace=True),
        nn.Linear(feat_channels, feat_channels),
        nn.ReLU(inplace=True),
        nn.Linear(feat_channels, out_channels),
    )
    self.conv_seg = None  # fix a bug here (conv_seg is not used)

    self.test_cfg = test_cfg
    self.train_cfg = train_cfg
    if train_cfg:
        # Assigner, sampler and point-sampling settings exist only when a
        # training config is supplied.
        self.assigner = build_assigner(self.train_cfg.assigner)
        self.sampler = build_sampler(self.train_cfg.sampler, context=self)
        self.num_points = self.train_cfg.get("num_points", 12544)
        self.oversample_ratio = self.train_cfg.get("oversample_ratio", 3.0)
        self.importance_sample_ratio = self.train_cfg.get("importance_sample_ratio", 0.75)

    self.class_weight = loss_cls.class_weight
    self.loss_cls = build_loss(loss_cls)
    self.loss_mask = build_loss(loss_mask)
    self.loss_dice = build_loss(loss_dice)
|
| 140 |
+
|
| 141 |
+
def init_weights(self):
    """Initialise the input projections, the pixel decoder and the transformer decoder."""
    # Only Conv2d projections are initialised; other entries are skipped.
    conv_projs = (proj for proj in self.decoder_input_projs if isinstance(proj, Conv2d))
    for proj in conv_projs:
        caffe2_xavier_init(proj, bias=0)

    self.pixel_decoder.init_weights()

    # Xavier-normal for every decoder parameter with more than one dimension.
    for param in self.transformer_decoder.parameters():
        if param.dim() <= 1:
            continue
        nn.init.xavier_normal_(param)
|
| 151 |
+
|
| 152 |
+
def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas):
    """Compute classification and mask targets for all images for one decoder layer.

    Each image is handled by ``_get_target_single``; the per-image results are
    gathered into lists and the positive/negative index counts are summed.

    Args:
        cls_scores_list (list[Tensor]): Per-image classification logits from a
            single decoder layer.
        mask_preds_list (list[Tensor]): Per-image mask logits from a single
            decoder layer.
        gt_labels_list (list[Tensor]): Per-image ground truth class indices.
        gt_masks_list (list[Tensor]): Per-image ground truth masks.
        img_metas (list[dict]): List of image meta information.

    Returns:
        tuple: ``(labels_list, label_weights_list, mask_targets_list,
        mask_weights_list, num_total_pos, num_total_neg)`` where the first four
        entries are per-image lists as produced by ``_get_target_single`` and
        the last two are the total numbers of positive and negative indices
        over all images.
    """
    per_image_targets = multi_apply(
        self._get_target_single, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas
    )
    (
        labels_list,
        label_weights_list,
        mask_targets_list,
        mask_weights_list,
        pos_inds_list,
        neg_inds_list,
    ) = per_image_targets

    num_total_pos = sum(inds.numel() for inds in pos_inds_list)
    num_total_neg = sum(inds.numel() for inds in neg_inds_list)
    return (labels_list, label_weights_list, mask_targets_list, mask_weights_list, num_total_pos, num_total_neg)
|
| 199 |
+
|
| 200 |
+
def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, img_metas):
    """Compute classification and mask targets for one image.

    Args:
        cls_score (Tensor): Classification logits from a single decoder layer
            for one image; its first dimension is the number of queries.
        mask_pred (Tensor): Mask logits from a single decoder layer for one
            image, one 2D map per query.
        gt_labels (Tensor): Ground truth class indices for one image; its
            first dimension is the number of ground truths.
        gt_masks (Tensor): Ground truth masks for one image, one 2D map per
            ground truth.
        img_metas (dict): Image information.

    Returns:
        tuple[Tensor]: A tuple containing the following for one image.

            - labels: shape (self.num_queries, ), filled with
              ``self.num_classes`` except at the positive indices.
            - label_weights: shape (self.num_queries, ), all ones.
            - mask_targets: ``gt_masks`` indexed by the assigned ground truth
              of every positive query.
            - mask_weights: shape (self.num_queries, ), 1.0 at the positive
              indices and 0 elsewhere.
            - pos_inds: Sampled positive indices.
            - neg_inds: Sampled negative indices.
    """
    n_queries = cls_score.shape[0]
    n_gts = gt_labels.shape[0]

    # Draw one set of random points and reuse it for predictions and ground
    # truth so both are compared at the same locations.
    point_coords = torch.rand((1, self.num_points, 2), device=cls_score.device)
    pred_coords = point_coords.repeat(n_queries, 1, 1)
    gt_coords = point_coords.repeat(n_gts, 1, 1)
    # shape (num_queries, num_points)
    mask_points_pred = point_sample(mask_pred.unsqueeze(1), pred_coords).squeeze(1)
    # shape (num_gts, num_points)
    gt_points_masks = point_sample(gt_masks.unsqueeze(1).float(), gt_coords).squeeze(1)

    # assign and sample
    assign_result = self.assigner.assign(cls_score, mask_points_pred, gt_labels, gt_points_masks, img_metas)
    sampling_result = self.sampler.sample(assign_result, mask_pred, gt_masks)
    pos_inds = sampling_result.pos_inds
    neg_inds = sampling_result.neg_inds
    matched_gt_inds = sampling_result.pos_assigned_gt_inds

    # label target
    # NOTE(review): targets are sized from self.num_queries while the point
    # sampling above uses cls_score.shape[0]; presumably they always agree.
    labels = gt_labels.new_full((self.num_queries,), self.num_classes, dtype=torch.long)
    labels[pos_inds] = gt_labels[matched_gt_inds]
    label_weights = gt_labels.new_ones((self.num_queries,))

    # mask target
    mask_targets = gt_masks[matched_gt_inds]
    mask_weights = mask_pred.new_zeros((self.num_queries,))
    mask_weights[pos_inds] = 1.0

    return (labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds)
|
| 257 |
+
|
| 258 |
+
def loss_single(self, cls_scores, mask_preds, gt_labels_list, gt_masks_list, img_metas):
    """Loss function for outputs from a single decoder layer.

    Args:
        cls_scores (Tensor): Classification logits from a single decoder layer
            for all images; indexed per image along dim 0 and per query along
            dim 1.
        mask_preds (Tensor): Mask logits for all images; indexed per image
            along dim 0 and per query along dim 1.
        gt_labels_list (list[Tensor]): Ground truth class indices for each image.
        gt_masks_list (list[Tensor]): Ground truth masks for each image.
        img_metas (list[dict]): List of image meta information.

    Returns:
        tuple[Tensor]: ``(loss_cls, loss_mask, loss_dice)`` for this decoder layer.
    """
    num_imgs = cls_scores.size(0)
    cls_scores_list = [cls_scores[img_idx] for img_idx in range(num_imgs)]
    mask_preds_list = [mask_preds[img_idx] for img_idx in range(num_imgs)]
    targets = self.get_targets(cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas)
    labels_list, label_weights_list, mask_targets_list, mask_weights_list, num_total_pos, _ = targets

    # Stacked per-image targets: shape (batch_size, num_queries).
    labels = torch.stack(labels_list, dim=0)
    label_weights = torch.stack(label_weights_list, dim=0)
    mask_weights = torch.stack(mask_weights_list, dim=0)
    # Matched ground truth masks of all images, concatenated along dim 0.
    mask_targets = torch.cat(mask_targets_list, dim=0)

    # Classification loss over all (image, query) pairs, normalised by the
    # summed class weight of the target labels.
    cls_scores = cls_scores.flatten(0, 1)
    labels = labels.flatten(0, 1)
    label_weights = label_weights.flatten(0, 1)

    class_weight = cls_scores.new_tensor(self.class_weight)
    loss_cls = self.loss_cls(cls_scores, labels, label_weights, avg_factor=class_weight[labels].sum())

    num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
    num_total_masks = max(num_total_masks, 1)

    # Keep only the predictions whose mask weight is positive, i.e. the
    # queries that were matched to a ground truth.
    mask_preds = mask_preds[mask_weights > 0]

    if mask_targets.shape[0] == 0:
        # zero match: return sums of the (empty) selection so the mask losses
        # stay connected to the graph.
        loss_dice = mask_preds.sum()
        loss_mask = mask_preds.sum()
        return loss_cls, loss_mask, loss_dice

    # Point selection itself is not differentiated through.
    with torch.no_grad():
        points_coords = get_uncertain_point_coords_with_randomness(
            mask_preds.unsqueeze(1), None, self.num_points, self.oversample_ratio, self.importance_sample_ratio
        )
        # Targets sampled at the chosen points; the channel dim added for
        # sampling is squeezed away again.
        mask_point_targets = point_sample(mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
    # Predictions sampled at the same points (with gradients).
    mask_point_preds = point_sample(mask_preds.unsqueeze(1), points_coords).squeeze(1)

    # dice loss
    loss_dice = self.loss_dice(mask_point_preds, mask_point_targets, avg_factor=num_total_masks)

    # mask loss: predictions become a single column, targets a flat vector.
    mask_point_preds = mask_point_preds.reshape(-1, 1)
    mask_point_targets = mask_point_targets.reshape(-1)
    loss_mask = self.loss_mask(mask_point_preds, mask_point_targets, avg_factor=num_total_masks * self.num_points)

    return loss_cls, loss_mask, loss_dice
|
| 340 |
+
|
| 341 |
+
@force_fp32(apply_to=("all_cls_scores", "all_mask_preds"))
def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, gt_masks_list, img_metas):
    """Compute the losses of every decoder layer.

    Args:
        all_cls_scores (Tensor): Classification scores for all decoder layers;
            iterated layer by layer along dim 0.
        all_mask_preds (Tensor): Mask scores for all decoder layers; iterated
            layer by layer along dim 0.
        gt_labels_list (list[Tensor]): Ground truth class indices for each image.
        gt_masks_list (list[Tensor]): Ground truth masks for each image.
        img_metas (list[dict]): List of image meta information.

    Returns:
        dict[str, Tensor]: ``loss_cls``/``loss_mask``/``loss_dice`` for the last
        decoder layer, plus ``d{i}.loss_*`` entries for every earlier layer ``i``.
    """
    num_dec_layers = len(all_cls_scores)
    # Every decoder layer is scored against the same ground truth and metas.
    losses_cls, losses_mask, losses_dice = multi_apply(
        self.loss_single,
        all_cls_scores,
        all_mask_preds,
        [gt_labels_list] * num_dec_layers,
        [gt_masks_list] * num_dec_layers,
        [img_metas] * num_dec_layers,
    )

    # loss from the last decoder layer
    loss_dict = {
        "loss_cls": losses_cls[-1],
        "loss_mask": losses_mask[-1],
        "loss_dice": losses_dice[-1],
    }
    # loss from other decoder layers
    earlier_layers = zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1])
    for layer_idx, (cls_i, mask_i, dice_i) in enumerate(earlier_layers):
        loss_dict[f"d{layer_idx}.loss_cls"] = cls_i
        loss_dict[f"d{layer_idx}.loss_mask"] = mask_i
        loss_dict[f"d{layer_idx}.loss_dice"] = dice_i
    return loss_dict
|
| 382 |
+
|
| 383 |
+
def forward_head(self, decoder_out, mask_feature, attn_mask_target_size):
    """Forward for the head part which is called after every decoder layer.

    Args:
        decoder_out (Tensor): in shape (num_queries, batch_size, c).
        mask_feature (Tensor): in shape (batch_size, c, h, w).
        attn_mask_target_size (tuple[int, int]): target attention mask size.

    Returns:
        tuple: A tuple containing three elements.

            - cls_pred (Tensor): Classification scores in shape
              (batch_size, num_queries, cls_out_channels).
            - mask_pred (Tensor): Mask scores in shape
              (batch_size, num_queries, h, w).
            - attn_mask (Tensor): Boolean attention mask in shape
              (batch_size * num_heads, num_queries, H * W), where (H, W) is
              ``attn_mask_target_size``; detached from the graph.
    """
    # Normalise, then switch to batch-first:
    # (num_queries, batch_size, c) -> (batch_size, num_queries, c)
    queries = self.transformer_decoder.post_norm(decoder_out).transpose(0, 1)
    cls_pred = self.cls_embed(queries)
    mask_embed = self.mask_embed(queries)
    # Contract the channel dim of every query embedding with the mask features.
    mask_pred = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature)

    resized = F.interpolate(mask_pred, attn_mask_target_size, mode="bilinear", align_corners=False)
    # Flatten the spatial dims, then replicate the mask once per attention head:
    # (batch_size, num_queries, H * W) -> (batch_size * num_heads, num_queries, H * W)
    per_head = resized.flatten(2).unsqueeze(1).repeat((1, self.num_heads, 1, 1)).flatten(0, 1)
    # True where the predicted foreground probability is below 0.5.
    attn_mask = (per_head.sigmoid() < 0.5).detach()

    return cls_pred, mask_pred, attn_mask
|
| 419 |
+
|
| 420 |
+
def forward(self, feats, img_metas):
    """Forward function.

    Args:
        feats (list[Tensor]): Multi scale Features from the
            upstream network, each is a 4D-tensor.
        img_metas (list[dict]): List of image information. Only its
            length is used here, as the batch size.

    Returns:
        tuple: A tuple containing two elements.

        - cls_pred_list (list[Tensor]): Classification predictions, one
          entry for the initial queries followed by one entry per
          decoder layer, as returned by ``self.forward_head``.
        - mask_pred_list (list[Tensor]): Mask predictions, with the same
          ordering as ``cls_pred_list``.
    """
    batch_size = len(img_metas)
    mask_features, multi_scale_memorys = self.pixel_decoder(feats)
    num_levels = self.num_transformer_feat_level

    # multi_scale_memorys (from low resolution to high resolution)
    decoder_inputs = []
    decoder_positional_encodings = []
    for level in range(num_levels):
        memory = multi_scale_memorys[level]
        # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
        projected = self.decoder_input_projs[level](memory).flatten(2).permute(2, 0, 1)
        projected = projected + self.level_embed.weight[level].view(1, 1, -1)
        # All-False mask of shape (batch_size, h, w) fed to the positional encoding.
        padding_mask = projected.new_zeros((batch_size,) + memory.shape[-2:], dtype=torch.bool)
        positional = self.decoder_positional_encoding(padding_mask)
        positional = positional.flatten(2).permute(2, 0, 1)
        decoder_inputs.append(projected)
        decoder_positional_encodings.append(positional)

    # shape (num_queries, c) -> (num_queries, batch_size, c)
    query_feat = self.query_feat.weight.unsqueeze(1).repeat((1, batch_size, 1))
    query_embed = self.query_embed.weight.unsqueeze(1).repeat((1, batch_size, 1))

    # Predictions from the initial queries, before any decoder layer runs.
    cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
    cls_pred_list = [cls_pred]
    mask_pred_list = [mask_pred]

    for layer_idx in range(self.num_transformer_decoder_layers):
        level_idx = layer_idx % num_levels
        # if a mask is all True(all background), then set it all False.
        fully_masked = torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])
        attn_mask[fully_masked] = False

        # cross_attn + self_attn
        decoder_layer = self.transformer_decoder.layers[layer_idx]
        query_feat = decoder_layer(
            query=query_feat,
            key=decoder_inputs[level_idx],
            value=decoder_inputs[level_idx],
            query_pos=query_embed,
            key_pos=decoder_positional_encodings[level_idx],
            attn_masks=[attn_mask, None],
            query_key_padding_mask=None,
            # here we do not apply masking on padded region
            key_padding_mask=None,
        )

        # The attention mask for the next layer is sized for the level
        # that layer will attend to.
        next_size = multi_scale_memorys[(layer_idx + 1) % num_levels].shape[-2:]
        cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, next_size)
        cls_pred_list.append(cls_pred)
        mask_pred_list.append(mask_pred)

    return cls_pred_list, mask_pred_list
|
| 493 |
+
|
| 494 |
+
def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, gt_masks):
    """Forward function for training mode.

    Runs the head on the features and passes the per-layer predictions,
    together with the ground truth, to ``self.loss``.

    Args:
        x (list[Tensor]): Multi-level features from the upstream network,
            each is a 4D-tensor.
        img_metas (list[Dict]): List of image information.
        gt_semantic_seg (list[tensor]): Each element is the ground truth
            of semantic segmentation with the shape (N, H, W). Accepted
            for interface compatibility; not used by this method.
        gt_labels (list[Tensor]): Each element is ground truth labels of
            each box, shape (num_gts,).
        gt_masks (list[BitmapMasks]): Each element is masks of instances
            of a image, shape (num_gts, h, w).

    Returns:
        losses (dict[str, Tensor]): a dictionary of loss components
    """
    # forward
    cls_scores_per_layer, mask_preds_per_layer = self(x, img_metas)

    # loss
    return self.loss(cls_scores_per_layer, mask_preds_per_layer, gt_labels, gt_masks, img_metas)
|
| 521 |
+
|
| 522 |
+
def forward_test(self, inputs, img_metas, test_cfg):
    """Test segment without test-time augmentation.

    Only the output of the last decoder layer is used.

    Args:
        inputs (list[Tensor]): Multi-level features from the
            upstream network, each is a 4D-tensor.
        img_metas (list[dict]): List of image information. It is only
            forwarded to ``self(...)``; no key is read from it here.
        test_cfg (dict): Testing config. Not used by this method.

    Returns:
        seg_mask (Tensor): Per-class semantic scores in shape
            (batch_size, num_classes, h, w), obtained by weighting each
            query's mask probability with its class probability. The
            last class channel of the classifier is dropped before the
            weighting.
    """
    all_cls_scores, all_mask_preds = self(inputs, img_metas)
    cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1]

    # semantic inference
    # Softmax over all classes first, then drop the last channel so that
    # it still takes part in the normalisation.
    cls_prob = F.softmax(cls_score, dim=-1)[..., :-1]
    mask_prob = mask_pred.sigmoid()
    # (batch_size, num_queries, classes) x (batch_size, num_queries, h, w)
    #   -> (batch_size, classes, h, w)
    seg_mask = torch.einsum("bqc,bqhw->bchw", cls_prob, mask_prob)
    return seg_mask
|
torch_hub/facebookresearch_dinov2_main/dinov2/eval/segmentation_m2f/models/utils/point_sample.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from mmcv.ops import point_sample
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_uncertainty(mask_pred, labels):
    """Estimate uncertainty based on pred logits.

    The uncertainty is the negated absolute value of the logit predicted
    for the class given in ``labels``: logits close to 0.0 score highest.

    Args:
        mask_pred (Tensor): mask prediction logits. The first dimension
            indexes the rois and the second the classes, e.g.
            (num_rois, num_classes, mask_height, mask_width).
        labels (list[Tensor]): Either predicted or ground truth label for
            each predicted mask, of length num_rois. Ignored when
            ``mask_pred`` has a single class channel.

    Returns:
        scores (Tensor): Uncertainty scores with the most uncertain
            locations having the highest uncertainty score. Same shape
            as ``mask_pred`` except that the class dimension has size 1.
    """
    if mask_pred.shape[1] != 1:
        # Class-specific prediction: pick each roi's own class channel.
        roi_index = torch.arange(mask_pred.shape[0], device=mask_pred.device)
        class_logits = mask_pred[roi_index, labels].unsqueeze(1)
    else:
        # Class-agnostic prediction: the single channel is used as is.
        class_logits = mask_pred.clone()
    return -torch.abs(class_logits)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_uncertain_point_coords_with_randomness(
    mask_pred, labels, num_points, oversample_ratio, importance_sample_ratio
):
    """Get ``num_points`` most uncertain points with random points during
    train.

    Candidate points are drawn uniformly in [0, 1] x [0, 1]; the most
    uncertain candidates (as scored by 'get_uncertainty()' on the logits
    sampled at those points) are kept and topped up with fresh uniformly
    random points.

    Args:
        mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
            mask_height, mask_width) for class-specific or class-agnostic
            prediction.
        labels (list): The ground truth class for each instance.
        num_points (int): The number of points to sample.
        oversample_ratio (int): Oversampling parameter.
        importance_sample_ratio (float): Ratio of points that are sampled
            via importance sampling.

    Returns:
        point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
            that contains the coordinates of the sampled points.
    """
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1

    num_rois = mask_pred.shape[0]
    device = mask_pred.device
    num_candidates = int(num_points * oversample_ratio)
    candidates = torch.rand(num_rois, num_candidates, 2, device=device)
    candidate_logits = point_sample(mask_pred, candidates)
    # It is crucial to calculate uncertainty based on the sampled
    # prediction value for the points. Calculating uncertainties of the
    # coarse predictions first and sampling them for points leads to
    # incorrect results. To illustrate this: assume uncertainty func(
    # logits)=-abs(logits), a sampled point between two coarse
    # predictions with -1 and 1 logits has 0 logits, and therefore 0
    # uncertainty value. However, if we calculate uncertainties for the
    # coarse predictions first, both will have -1 uncertainty,
    # and sampled point will get -1 uncertainty.
    uncertainties = get_uncertainty(candidate_logits, labels)

    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points

    _, top_idx = torch.topk(uncertainties[:, 0, :], k=num_uncertain_points, dim=1)
    # Offset each roi's indices so they address the flattened candidate list.
    roi_offsets = num_candidates * torch.arange(num_rois, dtype=torch.long, device=device)
    flat_idx = (top_idx + roi_offsets[:, None]).view(-1)
    point_coords = candidates.view(-1, 2)[flat_idx, :].view(num_rois, num_uncertain_points, 2)

    if num_random_points > 0:
        random_coords = torch.rand(num_rois, num_random_points, 2, device=device)
        point_coords = torch.cat((point_coords, random_coords), dim=1)
    return point_coords
|
torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (167 Bytes). View file
|
|
|
torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/backbones.cpython-310.pyc
ADDED
|
Binary file (4.6 kB). View file
|
|
|
torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/classifiers.cpython-310.pyc
ADDED
|
Binary file (6.32 kB). View file
|
|
|
torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/dinotxt.cpython-310.pyc
ADDED
|
Binary file (2.78 kB). View file
|
|
|
torch_hub/facebookresearch_dinov2_main/dinov2/hub/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (1.78 kB). View file
|
|
|
torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .decode_heads import BNHead, DPTHead
|
| 7 |
+
from .encoder_decoder import DepthEncoderDecoder
|
torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/__pycache__/ops.cpython-310.pyc
ADDED
|
Binary file (1.07 kB). View file
|
|
|
torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/decode_heads.py
ADDED
|
@@ -0,0 +1,747 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
from functools import partial
|
| 8 |
+
import math
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
from .ops import resize
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# XXX: (Untested) replacement for mmcv.imdenormalize()
|
| 18 |
+
def _imdenormalize(img, mean, std, to_bgr=True):
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
mean = mean.reshape(1, -1).astype(np.float64)
|
| 22 |
+
std = std.reshape(1, -1).astype(np.float64)
|
| 23 |
+
img = (img * std) + mean
|
| 24 |
+
if to_bgr:
|
| 25 |
+
img = img[::-1]
|
| 26 |
+
return img
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class DepthBaseDecodeHead(nn.Module):
    """Base class for depth decode heads.

    Holds the shared configuration, the final ``conv_depth`` projection
    and the helpers that turn a feature map into a depth map, compute the
    losses and prepare images for logging. Subclasses provide ``forward``.

    Args:
        in_channels (List): Input channels.
        channels (int): Channels after modules, before conv_depth.
        conv_layer (nn.Module): Conv layers. Default: None.
        act_layer (nn.Module): Activation layers. Default: nn.ReLU.
        loss_decode (nn.Module | nn.ModuleList | Sequence): Decode
            loss(es). Each loss is called as ``loss(depth_pred, depth_gt)``
            and must expose a ``loss_name`` attribute. Default: ().
        sampler (dict|None): The config of depth map sampler. Not used by
            this class. Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        min_depth (float): Min depth in dataset setting.
            Default: 1e-3.
        max_depth (float): Max depth in dataset setting. Required when
            ``classify`` or ``scale_up`` is set. Default: None.
        norm_layer (dict|None): Norm layers.
            Default: None.
        classify (bool): Whether predict depth in a cls.-reg. manner.
            Default: False.
        n_bins (int): The number of bins used in cls. step.
            Default: 256.
        bins_strategy (str): The discrete strategy used in cls. step.
            Default: 'UD'.
        norm_strategy (str): The norm strategy on cls. probability
            distribution. Default: 'linear'
        scale_up (bool): Whether predict depth in a scale-up manner.
            Default: False.
    """

    def __init__(
        self,
        in_channels,
        conv_layer=None,
        act_layer=nn.ReLU,
        channels=96,
        loss_decode=(),
        sampler=None,
        align_corners=False,
        min_depth=1e-3,
        max_depth=None,
        norm_layer=None,
        classify=False,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        scale_up=False,
    ):
        super(DepthBaseDecodeHead, self).__init__()

        self.in_channels = in_channels
        self.channels = channels
        # NOTE(review): the attribute name looks like a misspelling of
        # "conv_layer"; it is kept unchanged because other code may read it.
        self.conf_layer = conv_layer
        self.act_layer = act_layer
        self.loss_decode = loss_decode
        self.align_corners = align_corners
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.norm_layer = norm_layer
        self.classify = classify
        self.n_bins = n_bins
        self.scale_up = scale_up

        if self.classify:
            assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
            assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"

            self.bins_strategy = bins_strategy
            self.norm_strategy = norm_strategy
            self.softmax = nn.Softmax(dim=1)
            # One output channel per depth bin.
            self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
        else:
            # Direct regression: a single depth channel.
            self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, img_metas):
        """Placeholder of forward function."""
        pass

    def forward_train(self, img, inputs, img_metas, depth_gt):
        """Forward function for training.

        Args:
            img (Tensor): Batch of input images; only the first one is
                used, for logging.
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            depth_gt (Tensor): GT depth

        Returns:
            dict[str, Tensor]: a dictionary of loss components, merged
                with the entries returned by ``log_images`` for the first
                sample of the batch.
        """
        depth_pred = self.forward(inputs, img_metas)
        losses = self.losses(depth_pred, depth_gt)

        log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
        losses.update(**log_imgs)

        return losses

    def forward_test(self, inputs, img_metas):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.

        Returns:
            Tensor: Output depth map.
        """
        return self.forward(inputs, img_metas)

    def depth_pred(self, feat):
        """Predict the depth of each pixel from a feature map.

        Args:
            feat (Tensor): Feature map with ``self.channels`` channels.

        Returns:
            Tensor: Depth map with a single channel.
        """
        if self.classify:
            logit = self.conv_depth(feat)

            if self.bins_strategy == "UD":
                bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
            elif self.bins_strategy == "SID":
                # NOTE(review): torch.logspace treats its endpoints as
                # base-10 exponents, so the bins span
                # 10**min_depth .. 10**max_depth -- confirm this is intended.
                bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)

            # following Adabins, default linear
            if self.norm_strategy == "linear":
                logit = torch.relu(logit)
                # Keeps every bin weight strictly positive before normalising.
                eps = 0.1
                logit = logit + eps
                logit = logit / logit.sum(dim=1, keepdim=True)
            elif self.norm_strategy == "softmax":
                logit = torch.softmax(logit, dim=1)
            elif self.norm_strategy == "sigmoid":
                logit = torch.sigmoid(logit)
                logit = logit / logit.sum(dim=1, keepdim=True)

            # Depth is the bin-weight-weighted sum of the bin values.
            output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)

        else:
            if self.scale_up:
                output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
            else:
                output = self.relu(self.conv_depth(feat)) + self.min_depth
        return output

    def losses(self, depth_pred, depth_gt):
        """Compute depth loss.

        Args:
            depth_pred (Tensor): Predicted depth map; resized to the
                spatial size of ``depth_gt`` before the losses run.
            depth_gt (Tensor): GT depth.

        Returns:
            dict: Loss values keyed by each loss's ``loss_name``; losses
                sharing a name are summed. Empty when no loss is
                configured.
        """
        loss = dict()
        # A ModuleList, list or tuple holds several losses; anything else
        # is a single loss. (Wrapping the default ``()`` as if it were one
        # loss used to fail on ``().loss_name``.)
        if isinstance(self.loss_decode, (nn.ModuleList, list, tuple)):
            losses_decode = self.loss_decode
        else:
            losses_decode = [self.loss_decode]
        if len(losses_decode) == 0:
            return loss

        depth_pred = resize(
            input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
        )
        for loss_decode in losses_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
            else:
                loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
        return loss

    def log_images(self, img_path, depth_pred, depth_gt, img_meta):
        """Prepare one sample for image logging.

        Args:
            img_path (Tensor): Despite the name, the normalised image
                tensor itself, channel-first.
            depth_pred (Tensor): Predicted depth of the sample.
            depth_gt (Tensor): GT depth of the sample.
            img_meta (dict): Image info; ``img_meta['img_norm_cfg']`` must
                provide 'mean', 'std' and 'to_rgb'.

        Returns:
            dict: 'img_rgb' (uint8 array, channel-first), plus
                'img_depth_pred' and 'img_depth_gt' (CPU tensors, each
                divided by its own maximum).
        """
        import numpy as np

        # Channel-first tensor -> channel-last float array.
        show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
        show_img = show_img.numpy().astype(np.float32)
        show_img = _imdenormalize(
            show_img,
            img_meta["img_norm_cfg"]["mean"],
            img_meta["img_norm_cfg"]["std"],
            img_meta["img_norm_cfg"]["to_rgb"],
        )
        show_img = np.clip(show_img, 0, 255)
        show_img = show_img.astype(np.uint8)
        # Reverse the channel order, then go back to channel-first
        # (the two transposes together equal ``transpose(2, 0, 1)``).
        show_img = show_img[:, :, ::-1]
        show_img = show_img.transpose(0, 2, 1)
        show_img = show_img.transpose(1, 0, 2)

        depth_pred = depth_pred / torch.max(depth_pred)
        depth_gt = depth_gt / torch.max(depth_gt)

        depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
        depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())

        return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class BNHead(DepthBaseDecodeHead):
    """Depth head that applies a 1x1 ``conv_depth`` to the gathered features.

    The batch-norm layer the class is named after is commented out in
    this implementation.
    """

    def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
        super().__init__(**kwargs)
        self.input_transform = input_transform
        self.in_index = in_index
        self.upsample = upsample
        # self.bn = nn.SyncBatchNorm(self.in_channels)
        # Replace the base class's 3x3 projection with a 1x1 one.
        depth_channels = self.n_bins if self.classify else 1
        self.conv_depth = nn.Conv2d(self.channels, depth_channels, kernel_size=1, padding=0, stride=1)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """
        if "concat" in self.input_transform:
            selected = [inputs[i] for i in self.in_index]
            if "resize" in self.input_transform:
                # Every level is resized to the first selected level's
                # spatial size, scaled by ``self.upsample``.
                target_size = [s * self.upsample for s in selected[0].shape[2:]]
                selected = [
                    resize(
                        input=level,
                        size=target_size,
                        mode="bilinear",
                        align_corners=self.align_corners,
                    )
                    for level in selected
                ]
            return torch.cat(selected, dim=1)

        if self.input_transform == "multiple_select":
            return [inputs[i] for i in self.in_index]

        return inputs[self.in_index]

    def _forward_feature(self, inputs, img_metas=None, **kwargs):
        """Build the feature map that is fed to ``self.depth_pred``.

        Args:
            inputs (list): Multi-level img features. Each level is a
                sequence holding the feature tensor and, when the
                sequence has length 2, a class token as second element.

        Returns:
            feats (Tensor): The per-level features after
                ``self._transform_inputs``.
        """
        # accept lists (for cls token)
        prepared = []
        for level in inputs:
            if len(level) == 2:
                feat, cls_token = level[0], level[1]
                if len(feat.shape) == 2:
                    feat = feat[:, :, None, None]
                # Broadcast the class token over the spatial dims and
                # append it along the channel axis.
                cls_token = cls_token[:, :, None, None].expand_as(feat)
                prepared.append(torch.cat((feat, cls_token), 1))
            else:
                feat = level[0]
                if len(feat.shape) == 2:
                    feat = feat[:, :, None, None]
                prepared.append(feat)
        feats = self._transform_inputs(prepared)
        # feats = self.bn(x)
        return feats

    def forward(self, inputs, img_metas=None, **kwargs):
        """Forward function."""
        feats = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
        return self.depth_pred(feats)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class ConvModule(nn.Module):
|
| 300 |
+
"""A conv block that bundles conv/norm/activation layers.
|
| 301 |
+
|
| 302 |
+
This block simplifies the usage of convolution layers, which are commonly
|
| 303 |
+
used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
|
| 304 |
+
It is based upon three build methods: `build_conv_layer()`,
|
| 305 |
+
`build_norm_layer()` and `build_activation_layer()`.
|
| 306 |
+
|
| 307 |
+
Besides, we add some additional features in this module.
|
| 308 |
+
1. Automatically set `bias` of the conv layer.
|
| 309 |
+
2. Spectral norm is supported.
|
| 310 |
+
3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
|
| 311 |
+
supports zero and circular padding, and we add "reflect" padding mode.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
in_channels (int): Number of channels in the input feature map.
|
| 315 |
+
Same as that in ``nn._ConvNd``.
|
| 316 |
+
out_channels (int): Number of channels produced by the convolution.
|
| 317 |
+
Same as that in ``nn._ConvNd``.
|
| 318 |
+
kernel_size (int | tuple[int]): Size of the convolving kernel.
|
| 319 |
+
Same as that in ``nn._ConvNd``.
|
| 320 |
+
stride (int | tuple[int]): Stride of the convolution.
|
| 321 |
+
Same as that in ``nn._ConvNd``.
|
| 322 |
+
padding (int | tuple[int]): Zero-padding added to both sides of
|
| 323 |
+
the input. Same as that in ``nn._ConvNd``.
|
| 324 |
+
dilation (int | tuple[int]): Spacing between kernel elements.
|
| 325 |
+
Same as that in ``nn._ConvNd``.
|
| 326 |
+
groups (int): Number of blocked connections from input channels to
|
| 327 |
+
output channels. Same as that in ``nn._ConvNd``.
|
| 328 |
+
bias (bool | str): If specified as `auto`, it will be decided by the
|
| 329 |
+
norm_layer. Bias will be set as True if `norm_layer` is None, otherwise
|
| 330 |
+
False. Default: "auto".
|
| 331 |
+
conv_layer (nn.Module): Convolution layer. Default: None,
|
| 332 |
+
which means using conv2d.
|
| 333 |
+
norm_layer (nn.Module): Normalization layer. Default: None.
|
| 334 |
+
act_layer (nn.Module): Activation layer. Default: nn.ReLU.
|
| 335 |
+
inplace (bool): Whether to use inplace mode for activation.
|
| 336 |
+
Default: True.
|
| 337 |
+
with_spectral_norm (bool): Whether use spectral norm in conv module.
|
| 338 |
+
Default: False.
|
| 339 |
+
padding_mode (str): If the `padding_mode` has not been supported by
|
| 340 |
+
current `Conv2d` in PyTorch, we will use our own padding layer
|
| 341 |
+
instead. Currently, we support ['zeros', 'circular'] with official
|
| 342 |
+
implementation and ['reflect'] with our own implementation.
|
| 343 |
+
Default: 'zeros'.
|
| 344 |
+
order (tuple[str]): The order of conv/norm/activation layers. It is a
|
| 345 |
+
sequence of "conv", "norm" and "act". Common examples are
|
| 346 |
+
("conv", "norm", "act") and ("act", "conv", "norm").
|
| 347 |
+
Default: ('conv', 'norm', 'act').
|
| 348 |
+
"""
|
| 349 |
+
|
| 350 |
+
_abbr_ = "conv_block"
|
| 351 |
+
|
| 352 |
+
def __init__(
    self,
    in_channels,
    out_channels,
    kernel_size,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias="auto",
    conv_layer=nn.Conv2d,
    norm_layer=None,
    act_layer=nn.ReLU,
    inplace=True,
    with_spectral_norm=False,
    padding_mode="zeros",
    order=("conv", "norm", "act"),
):
    """Build the padding, convolution, normalization and activation layers.

    Args:
        in_channels (int): Input channels of the convolution.
        out_channels (int): Output channels of the convolution.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
        stride, padding, dilation, groups: Forwarded to ``conv_layer``.
        bias (bool | str): ``"auto"`` enables the conv bias only when no
            norm layer is used.
        conv_layer (callable): Factory for the convolution layer.
        norm_layer (callable | None): Factory called as
            ``norm_layer(num_features=C)``; ``None`` disables normalization.
        act_layer (callable | None): Factory for the activation layer;
            ``None`` disables activation.
        inplace (bool): Passed as ``inplace`` to activations that accept it.
        with_spectral_norm (bool): Wrap the conv with spectral norm.
        padding_mode (str): ``"zeros"``/``"circular"`` leave padding to the
            conv layer; ``"reflect"`` uses an explicit ``nn.ReflectionPad2d``.
        order (tuple[str]): Permutation of ``("conv", "norm", "act")``.

    Raises:
        AssertionError: If ``order`` is invalid or ``padding_mode`` is not
            supported.
    """
    super(ConvModule, self).__init__()
    official_padding_mode = ["zeros", "circular"]
    self.conv_layer = conv_layer
    self.norm_layer = norm_layer
    self.act_layer = act_layer
    self.inplace = inplace
    self.with_spectral_norm = with_spectral_norm
    self.with_explicit_padding = padding_mode not in official_padding_mode
    self.order = order
    assert isinstance(self.order, tuple) and len(self.order) == 3
    assert set(order) == {"conv", "norm", "act"}

    self.with_norm = norm_layer is not None
    self.with_activation = act_layer is not None
    # if the conv layer is before a norm layer, bias is unnecessary.
    if bias == "auto":
        bias = not self.with_norm
    self.with_bias = bias

    if self.with_explicit_padding:
        if padding_mode != "reflect":
            raise AssertionError(f"Unsupported padding mode: {padding_mode}")
        if isinstance(padding, int):
            pad_widths = padding
        else:
            # Conv padding is (pad_h, pad_w); the pad layer expects
            # (left, right, top, bottom).
            pad_h, pad_w = padding
            pad_widths = (pad_w, pad_w, pad_h, pad_h)
        self.pad = nn.ReflectionPad2d(pad_widths)

    # NOTE(review): ``padding_mode="circular"`` is accepted above but is not
    # forwarded to ``conv_layer``, so the conv keeps its own default mode.
    # reset padding to 0 for conv module
    conv_padding = 0 if self.with_explicit_padding else padding
    # build convolution layer
    self.conv = self.conv_layer(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=conv_padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    # export the attributes of self.conv to a higher level for convenience
    self.in_channels = self.conv.in_channels
    self.out_channels = self.conv.out_channels
    self.kernel_size = self.conv.kernel_size
    self.stride = self.conv.stride
    self.padding = padding
    self.dilation = self.conv.dilation
    self.transposed = self.conv.transposed
    self.output_padding = self.conv.output_padding
    self.groups = self.conv.groups

    if self.with_spectral_norm:
        self.conv = nn.utils.spectral_norm(self.conv)

    # build normalization layers
    if self.with_norm:
        # norm layer is after conv layer
        if order.index("norm") > order.index("conv"):
            norm_channels = out_channels
        else:
            norm_channels = in_channels
        # Instantiate the layer: only real modules can be registered, and the
        # registration name must differ from the ``norm`` property, which
        # resolves the module through ``norm_name``.
        norm = norm_layer(num_features=norm_channels)
        self.norm_name = "norm_module"
        self.add_module(self.norm_name, norm)
        if self.with_bias:
            from torch.nn.modules.batchnorm import _BatchNorm
            from torch.nn.modules.instancenorm import _InstanceNorm

            if isinstance(norm, (_BatchNorm, _InstanceNorm)):
                warnings.warn("Unnecessary conv bias before batch/instance norm")
    else:
        self.norm_name = None

    # build activation layer
    if self.with_activation:
        # These activations have no 'inplace' argument. ``act_layer`` is a
        # factory (usually a class), so test it with issubclass.
        no_inplace = (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)
        lacks_inplace = isinstance(act_layer, type) and issubclass(act_layer, no_inplace)
        if not lacks_inplace:
            act_layer = partial(act_layer, inplace=inplace)
        self.activate = act_layer()

    # Use msra init by default
    self.init_weights()
|
| 451 |
+
|
| 452 |
+
@property
def norm(self):
    """Return the attribute named by ``self.norm_name``, or ``None`` if that name is falsy."""
    return getattr(self, self.norm_name) if self.norm_name else None
|
| 458 |
+
|
| 459 |
+
def init_weights(self):
    """Initialize the conv (Kaiming normal) and norm (weight 1, bias 0) parameters.

    The conv is left untouched when it provides its own ``init_weights``
    attribute, so customized conv layers keep their own initialization.
    PyTorch's built-in conv layers are overwritten with Kaiming-normal
    weights (``mode="fan_out"``) and a zero bias.
    """
    if not hasattr(self.conv, "init_weights"):
        # Inspect the built activation module: ``self.act_layer`` is the
        # factory (a class), so an isinstance test on it never matches.
        activation = getattr(self, "activate", None) if self.with_activation else None
        if isinstance(activation, nn.LeakyReLU):
            nonlinearity = "leaky_relu"
            a = activation.negative_slope
        else:
            nonlinearity = "relu"
            a = 0
        if hasattr(self.conv, "weight") and self.conv.weight is not None:
            nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
        if hasattr(self.conv, "bias") and self.conv.bias is not None:
            nn.init.constant_(self.conv.bias, 0)
    if self.with_norm:
        norm = self.norm
        if hasattr(norm, "weight") and norm.weight is not None:
            nn.init.constant_(norm.weight, 1)
        if hasattr(norm, "bias") and norm.bias is not None:
            nn.init.constant_(norm.bias, 0)
|
| 485 |
+
|
| 486 |
+
def forward(self, x, activate=True, norm=True):
    """Apply the conv / norm / act stages to ``x`` in the configured order.

    Args:
        x: Input passed through the stages.
        activate (bool): When False, the activation stage is skipped.
        norm (bool): When False, the normalization stage is skipped.

    Returns:
        The result after all enabled stages have been applied.
    """
    for stage in self.order:
        if stage == "conv":
            # Explicit padding (if any) is applied right before the conv.
            conv_input = self.pad(x) if self.with_explicit_padding else x
            x = self.conv(conv_input)
            continue
        if stage == "norm":
            if norm and self.with_norm:
                x = self.norm(x)
            continue
        if stage == "act" and activate and self.with_activation:
            x = self.activate(x)
    return x
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
class Interpolate(nn.Module):
    """Module wrapper that calls ``nn.functional.interpolate`` with fixed settings.

    Args:
        scale_factor: Forwarded as ``scale_factor``.
        mode (str): Forwarded as ``mode``.
        align_corners (bool): Forwarded as ``align_corners``. Default: False.
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners
        self.interp = nn.functional.interpolate

    def forward(self, x):
        """Return ``x`` interpolated with the stored scale factor, mode and corner alignment."""
        return self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
class HeadDepth(nn.Module):
    """Convolutional head mapping a ``features``-channel map to one output channel.

    The stack is: 3x3 conv halving the channels, 2x bilinear upsampling,
    3x3 conv to 32 channels, ReLU, and a 1x1 conv to a single channel.

    Args:
        features (int): Number of input channels.
    """

    def __init__(self, features):
        super().__init__()
        half = features // 2
        stages = [
            nn.Conv2d(features, half, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(half, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        ]
        self.head = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the head and return the result."""
        return self.head(x)
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class ReassembleBlocks(nn.Module):
    """ViTPostProcessBlock: fold the cls token into each ViT feature map and resize it.

    Every stage gets a 1x1 projection followed by a fixed resize layer:
    4x transposed-conv upsampling, 2x transposed-conv upsampling, identity,
    and a stride-2 conv, in that order.

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List): output channels of each stage.
            Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation, one of 'ignore',
            'add' or 'project'. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
    """

    def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16):
        super().__init__()

        assert readout_type in ["ignore", "add", "project"]
        self.readout_type = readout_type
        self.patch_size = patch_size

        # One 1x1 projection (no activation) per output stage.
        stage_projections = [
            ConvModule(in_channels=in_channels, out_channels=width, kernel_size=1, act_layer=None)
            for width in out_channels
        ]
        self.projects = nn.ModuleList(stage_projections)

        upsample_4x = nn.ConvTranspose2d(
            in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
        )
        upsample_2x = nn.ConvTranspose2d(
            in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
        )
        keep_size = nn.Identity()
        downsample_2x = nn.Conv2d(
            in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
        )
        self.resize_layers = nn.ModuleList([upsample_4x, upsample_2x, keep_size, downsample_2x])

        if self.readout_type == "project":
            # One readout MLP per projection; it consumes [token, cls] pairs.
            self.readout_projects = nn.ModuleList(
                [nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()) for _ in self.projects]
            )

    def forward(self, inputs):
        """Process a list of ``(feature_map, cls_token)`` pairs, one per stage.

        Returns:
            list: The projected and resized feature map of every stage.
        """
        assert isinstance(inputs, list)
        out = []
        for idx, pair in enumerate(inputs):
            assert len(pair) == 2
            feat, cls_token = pair[0], pair[1]
            feature_shape = feat.shape
            if self.readout_type == "project":
                # Flatten to a token sequence, concatenate the cls token onto
                # every token, project, then restore the map layout.
                tokens = feat.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(tokens)
                tokens = self.readout_projects[idx](torch.cat((tokens, readout), -1))
                feat = tokens.permute(0, 2, 1).reshape(feature_shape)
            elif self.readout_type == "add":
                feat = (feat.flatten(2) + cls_token.unsqueeze(-1)).reshape(feature_shape)
            # 'ignore': the cls token is dropped and the map is used as is.
            feat = self.projects[idx](feat)
            out.append(self.resize_layers[idx](feat))
        return out
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
class PreActResidualConvUnit(nn.Module):
    """ResidualConvUnit, pre-activate residual unit.

    Two bias-free 3x3 ``ConvModule`` blocks, each ordered act -> conv -> norm,
    whose output is added to the unit's input.

    Args:
        in_channels (int): number of channels in the input feature map.
        act_layer (nn.Module): activation layer.
        norm_layer (nn.Module): norm layer.
        stride (int): stride of the first block. Default: 1
        dilation (int): dilation rate for convs layers. Default: 1.
    """

    def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
        super().__init__()

        # Settings common to both convolutions.
        shared = dict(
            norm_layer=norm_layer,
            act_layer=act_layer,
            bias=False,
            order=("act", "conv", "norm"),
        )
        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            **shared,
        )
        self.conv2 = ConvModule(in_channels, in_channels, 3, padding=1, **shared)

    def forward(self, inputs):
        """Return ``conv2(conv1(inputs))`` plus a copy of ``inputs``."""
        # Copy first: the leading activation may be configured to run in place.
        identity = inputs.clone()
        residual = self.conv2(self.conv1(inputs))
        return residual + identity
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
class FeatureFusionBlock(nn.Module):
    """FeatureFusionBlock, merge feature map from different stages.

    Args:
        in_channels (int): Input channels.
        act_layer (nn.Module): activation layer for ResidualConvUnit.
        norm_layer (nn.Module): normalization layer.
        expand (bool): Whether expand the channels in post process block.
            Default: False.
        align_corners (bool): align_corner setting for bilinear upsample.
            Default: True.
    """

    def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
        super().__init__()

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners
        # With ``expand`` the projection halves the channel count.
        self.out_channels = in_channels // 2 if self.expand else in_channels

        self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)

        unit_kwargs = dict(in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer)
        self.res_conv_unit1 = PreActResidualConvUnit(**unit_kwargs)
        self.res_conv_unit2 = PreActResidualConvUnit(**unit_kwargs)

    def forward(self, *inputs):
        """Fuse one or two feature maps, upsample by 2 and project.

        With two inputs, the second is resized to the first one's spatial
        size when their shapes differ, refined by ``res_conv_unit1`` and
        added to the first.
        """
        x = inputs[0]
        if len(inputs) == 2:
            skip = inputs[1]
            if x.shape != skip.shape:
                skip = resize(skip, size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
            x = x + self.res_conv_unit1(skip)
        x = self.res_conv_unit2(x)
        x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
        return self.project(x)
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
class DPTHead(DepthBaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head is implemented of `DPT <https://arxiv.org/abs/2103.13413>`_.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (List): Out channels of post process conv
            layers. Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether expand the channels in post process
            block. Default: False.
        **kwargs: Forwarded to ``DepthBaseDecodeHead.__init__``.
    """

    def __init__(
        self,
        embed_dims=768,
        post_process_channels=[96, 192, 384, 768],
        readout_type="ignore",
        patch_size=16,
        expand_channels=False,
        **kwargs,
    ):
        super(DPTHead, self).__init__(**kwargs)

        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)

        # Use integer arithmetic: a float channel count (as produced by
        # math.pow) is not a valid size for a convolution.
        # NOTE(review): ``reassemble_blocks`` above is built from the
        # un-expanded ``post_process_channels`` while the convs below take the
        # expanded widths as input -- confirm the expand_channels=True path.
        self.post_process_channels = [
            channel * 2**i if expand_channels else channel for i, channel in enumerate(post_process_channels)
        ]
        self.convs = nn.ModuleList()
        for channel in self.post_process_channels:
            self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False))
        self.fusion_blocks = nn.ModuleList()
        for _ in range(len(self.convs)):
            self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer))
        # The first fusion block receives a single input, so its skip-branch
        # residual unit is never used.
        self.fusion_blocks[0].res_conv_unit1 = None
        self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)
        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels
        self.conv_depth = HeadDepth(self.channels)

    def forward(self, inputs, img_metas):
        """Decode backbone features into a depth prediction.

        Args:
            inputs: One entry per reassemble stage; each entry is passed to
                ``ReassembleBlocks`` unchanged.
            img_metas: Unused by this method.

        Returns:
            The result of ``self.depth_pred`` applied to the fused features.
        """
        assert len(inputs) == self.num_reassemble_blocks
        x = self.reassemble_blocks(list(inputs))
        x = [self.convs[i](feature) for i, feature in enumerate(x)]
        # Fuse from the last stage back to the first.
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, len(self.fusion_blocks)):
            out = self.fusion_blocks[i](out, x[-(i + 1)])
        out = self.project(out)
        out = self.depth_pred(out)
        return out
|
torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/encoder_decoder.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
from .ops import resize
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def add_prefix(inputs, prefix):
    """Return a copy of ``inputs`` whose keys are prefixed with ``prefix`` and a dot.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: A new dict mapping ``"<prefix>.<key>"`` to the original value.
    """
    return {f"{prefix}.{name}": value for name, value in inputs.items()}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class DepthEncoderDecoder(nn.Module):
|
| 35 |
+
"""Encoder Decoder depther.
|
| 36 |
+
|
| 37 |
+
EncoderDecoder typically consists of backbone and decode_head.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self, backbone, decode_head):
|
| 41 |
+
super(DepthEncoderDecoder, self).__init__()
|
| 42 |
+
|
| 43 |
+
self.backbone = backbone
|
| 44 |
+
self.decode_head = decode_head
|
| 45 |
+
self.align_corners = self.decode_head.align_corners
|
| 46 |
+
|
| 47 |
+
def extract_feat(self, img):
|
| 48 |
+
"""Extract features from images."""
|
| 49 |
+
return self.backbone(img)
|
| 50 |
+
|
| 51 |
+
def encode_decode(self, img, img_metas, rescale=True, size=None):
|
| 52 |
+
"""Encode images with backbone and decode into a depth estimation
|
| 53 |
+
map of the same size as input."""
|
| 54 |
+
x = self.extract_feat(img)
|
| 55 |
+
out = self._decode_head_forward_test(x, img_metas)
|
| 56 |
+
# crop the pred depth to the certain range.
|
| 57 |
+
out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
|
| 58 |
+
if rescale:
|
| 59 |
+
if size is None:
|
| 60 |
+
if img_metas is not None:
|
| 61 |
+
size = img_metas[0]["ori_shape"][:2]
|
| 62 |
+
else:
|
| 63 |
+
size = img.shape[2:]
|
| 64 |
+
out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
|
| 65 |
+
return out
|
| 66 |
+
|
| 67 |
+
def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
|
| 68 |
+
"""Run forward function and calculate loss for decode head in
|
| 69 |
+
training."""
|
| 70 |
+
losses = dict()
|
| 71 |
+
loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs)
|
| 72 |
+
losses.update(add_prefix(loss_decode, "decode"))
|
| 73 |
+
return losses
|
| 74 |
+
|
| 75 |
+
def _decode_head_forward_test(self, x, img_metas):
|
| 76 |
+
"""Run forward function and calculate loss for decode head in
|
| 77 |
+
inference."""
|
| 78 |
+
depth_pred = self.decode_head.forward_test(x, img_metas)
|
| 79 |
+
return depth_pred
|
| 80 |
+
|
| 81 |
+
def forward_dummy(self, img):
|
| 82 |
+
"""Dummy forward function."""
|
| 83 |
+
depth = self.encode_decode(img, None)
|
| 84 |
+
|
| 85 |
+
return depth
|
| 86 |
+
|
| 87 |
+
def forward_train(self, img, img_metas, depth_gt, **kwargs):
|
| 88 |
+
"""Forward function for training.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
img (Tensor): Input images.
|
| 92 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 93 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 94 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 95 |
+
For details on the values of these keys see
|
| 96 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
| 97 |
+
depth_gt (Tensor): Depth gt
|
| 98 |
+
used if the architecture supports depth estimation task.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
dict[str, Tensor]: a dictionary of loss components
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
x = self.extract_feat(img)
|
| 105 |
+
|
| 106 |
+
losses = dict()
|
| 107 |
+
|
| 108 |
+
# the last of x saves the info from neck
|
| 109 |
+
loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
|
| 110 |
+
|
| 111 |
+
losses.update(loss_decode)
|
| 112 |
+
|
| 113 |
+
return losses
|
| 114 |
+
|
| 115 |
+
def whole_inference(self, img, img_meta, rescale, size=None):
|
| 116 |
+
"""Inference with full image."""
|
| 117 |
+
return self.encode_decode(img, img_meta, rescale, size=size)
|
| 118 |
+
|
| 119 |
+
def slide_inference(self, img, img_meta, rescale, stride, crop_size):
|
| 120 |
+
"""Inference by sliding-window with overlap.
|
| 121 |
+
|
| 122 |
+
If h_crop > h_img or w_crop > w_img, the small patch will be used to
|
| 123 |
+
decode without padding.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
h_stride, w_stride = stride
|
| 127 |
+
h_crop, w_crop = crop_size
|
| 128 |
+
batch_size, _, h_img, w_img = img.size()
|
| 129 |
+
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
|
| 130 |
+
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
|
| 131 |
+
preds = img.new_zeros((batch_size, 1, h_img, w_img))
|
| 132 |
+
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
|
| 133 |
+
for h_idx in range(h_grids):
|
| 134 |
+
for w_idx in range(w_grids):
|
| 135 |
+
y1 = h_idx * h_stride
|
| 136 |
+
x1 = w_idx * w_stride
|
| 137 |
+
y2 = min(y1 + h_crop, h_img)
|
| 138 |
+
x2 = min(x1 + w_crop, w_img)
|
| 139 |
+
y1 = max(y2 - h_crop, 0)
|
| 140 |
+
x1 = max(x2 - w_crop, 0)
|
| 141 |
+
crop_img = img[:, :, y1:y2, x1:x2]
|
| 142 |
+
depth_pred = self.encode_decode(crop_img, img_meta, rescale)
|
| 143 |
+
preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
|
| 144 |
+
|
| 145 |
+
count_mat[:, :, y1:y2, x1:x2] += 1
|
| 146 |
+
assert (count_mat == 0).sum() == 0
|
| 147 |
+
if torch.onnx.is_in_onnx_export():
|
| 148 |
+
# cast count_mat to constant while exporting to ONNX
|
| 149 |
+
count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
|
| 150 |
+
preds = preds / count_mat
|
| 151 |
+
return preds
|
| 152 |
+
|
| 153 |
+
def inference(self, img, img_meta, rescale, size=None, mode="whole"):
|
| 154 |
+
"""Inference with slide/whole style.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
img (Tensor): The input image of shape (N, 3, H, W).
|
| 158 |
+
img_meta (dict): Image info dict where each dict has: 'img_shape',
|
| 159 |
+
'scale_factor', 'flip', and may also contain
|
| 160 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 161 |
+
For details on the values of these keys see
|
| 162 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
| 163 |
+
rescale (bool): Whether rescale back to original shape.
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
Tensor: The output depth map.
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
assert mode in ["slide", "whole"]
|
| 170 |
+
ori_shape = img_meta[0]["ori_shape"]
|
| 171 |
+
assert all(_["ori_shape"] == ori_shape for _ in img_meta)
|
| 172 |
+
if mode == "slide":
|
| 173 |
+
depth_pred = self.slide_inference(img, img_meta, rescale)
|
| 174 |
+
else:
|
| 175 |
+
depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
|
| 176 |
+
output = depth_pred
|
| 177 |
+
flip = img_meta[0]["flip"]
|
| 178 |
+
if flip:
|
| 179 |
+
flip_direction = img_meta[0]["flip_direction"]
|
| 180 |
+
assert flip_direction in ["horizontal", "vertical"]
|
| 181 |
+
if flip_direction == "horizontal":
|
| 182 |
+
output = output.flip(dims=(3,))
|
| 183 |
+
elif flip_direction == "vertical":
|
| 184 |
+
output = output.flip(dims=(2,))
|
| 185 |
+
|
| 186 |
+
return output
|
| 187 |
+
|
| 188 |
+
def simple_test(self, img, img_meta, rescale=True):
|
| 189 |
+
"""Simple test with single image."""
|
| 190 |
+
depth_pred = self.inference(img, img_meta, rescale)
|
| 191 |
+
if torch.onnx.is_in_onnx_export():
|
| 192 |
+
# our inference backend only support 4D output
|
| 193 |
+
depth_pred = depth_pred.unsqueeze(0)
|
| 194 |
+
return depth_pred
|
| 195 |
+
depth_pred = depth_pred.cpu().numpy()
|
| 196 |
+
# unravel batch dim
|
| 197 |
+
depth_pred = list(depth_pred)
|
| 198 |
+
return depth_pred
|
| 199 |
+
|
| 200 |
+
def aug_test(self, imgs, img_metas, rescale=True):
|
| 201 |
+
"""Test with augmentations.
|
| 202 |
+
|
| 203 |
+
Only rescale=True is supported.
|
| 204 |
+
"""
|
| 205 |
+
# aug_test rescale all imgs back to ori_shape for now
|
| 206 |
+
assert rescale
|
| 207 |
+
# to save memory, we get augmented depth logit inplace
|
| 208 |
+
depth_pred = self.inference(imgs[0], img_metas[0], rescale)
|
| 209 |
+
for i in range(1, len(imgs)):
|
| 210 |
+
cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
|
| 211 |
+
depth_pred += cur_depth_pred
|
| 212 |
+
depth_pred /= len(imgs)
|
| 213 |
+
depth_pred = depth_pred.cpu().numpy()
|
| 214 |
+
# unravel batch dim
|
| 215 |
+
depth_pred = list(depth_pred)
|
| 216 |
+
return depth_pred
|
| 217 |
+
|
| 218 |
+
def forward_test(self, imgs, img_metas, **kwargs):
|
| 219 |
+
"""
|
| 220 |
+
Args:
|
| 221 |
+
imgs (List[Tensor]): the outer list indicates test-time
|
| 222 |
+
augmentations and inner Tensor should have a shape NxCxHxW,
|
| 223 |
+
which contains all images in the batch.
|
| 224 |
+
img_metas (List[List[dict]]): the outer list indicates test-time
|
| 225 |
+
augs (multiscale, flip, etc.) and the inner list indicates
|
| 226 |
+
images in a batch.
|
| 227 |
+
"""
|
| 228 |
+
for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
|
| 229 |
+
if not isinstance(var, list):
|
| 230 |
+
raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
|
| 231 |
+
num_augs = len(imgs)
|
| 232 |
+
if num_augs != len(img_metas):
|
| 233 |
+
raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
|
| 234 |
+
# all images in the same aug batch all of the same ori_shape and pad
|
| 235 |
+
# shape
|
| 236 |
+
for img_meta in img_metas:
|
| 237 |
+
ori_shapes = [_["ori_shape"] for _ in img_meta]
|
| 238 |
+
assert all(shape == ori_shapes[0] for shape in ori_shapes)
|
| 239 |
+
img_shapes = [_["img_shape"] for _ in img_meta]
|
| 240 |
+
assert all(shape == img_shapes[0] for shape in img_shapes)
|
| 241 |
+
pad_shapes = [_["pad_shape"] for _ in img_meta]
|
| 242 |
+
assert all(shape == pad_shapes[0] for shape in pad_shapes)
|
| 243 |
+
|
| 244 |
+
if num_augs == 1:
|
| 245 |
+
return self.simple_test(imgs[0], img_metas[0], **kwargs)
|
| 246 |
+
else:
|
| 247 |
+
return self.aug_test(imgs, img_metas, **kwargs)
|
| 248 |
+
|
| 249 |
+
def forward(self, img, img_metas, return_loss=True, **kwargs):
|
| 250 |
+
"""Calls either :func:`forward_train` or :func:`forward_test` depending
|
| 251 |
+
on whether ``return_loss`` is ``True``.
|
| 252 |
+
|
| 253 |
+
Note this setting will change the expected inputs. When
|
| 254 |
+
``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
|
| 255 |
+
and List[dict]), and when ``resturn_loss=False``, img and img_meta
|
| 256 |
+
should be double nested (i.e. List[Tensor], List[List[dict]]), with
|
| 257 |
+
the outer list indicating test time augmentations.
|
| 258 |
+
"""
|
| 259 |
+
if return_loss:
|
| 260 |
+
return self.forward_train(img, img_metas, **kwargs)
|
| 261 |
+
else:
|
| 262 |
+
return self.forward_test(img, img_metas, **kwargs)
|
| 263 |
+
|
| 264 |
+
def train_step(self, data_batch, optimizer, **kwargs):
|
| 265 |
+
"""The iteration step during training.
|
| 266 |
+
|
| 267 |
+
This method defines an iteration step during training, except for the
|
| 268 |
+
back propagation and optimizer updating, which are done in an optimizer
|
| 269 |
+
hook. Note that in some complicated cases or models, the whole process
|
| 270 |
+
including back propagation and optimizer updating is also defined in
|
| 271 |
+
this method, such as GAN.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
data (dict): The output of dataloader.
|
| 275 |
+
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
|
| 276 |
+
runner is passed to ``train_step()``. This argument is unused
|
| 277 |
+
and reserved.
|
| 278 |
+
|
| 279 |
+
Returns:
|
| 280 |
+
dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
|
| 281 |
+
``num_samples``.
|
| 282 |
+
``loss`` is a tensor for back propagation, which can be a
|
| 283 |
+
weighted sum of multiple losses.
|
| 284 |
+
``log_vars`` contains all the variables to be sent to the
|
| 285 |
+
logger.
|
| 286 |
+
``num_samples`` indicates the batch size (when the model is
|
| 287 |
+
DDP, it means the batch size on each GPU), which is used for
|
| 288 |
+
averaging the logs.
|
| 289 |
+
"""
|
| 290 |
+
losses = self(**data_batch)
|
| 291 |
+
|
| 292 |
+
# split losses and images
|
| 293 |
+
real_losses = {}
|
| 294 |
+
log_imgs = {}
|
| 295 |
+
for k, v in losses.items():
|
| 296 |
+
if "img" in k:
|
| 297 |
+
log_imgs[k] = v
|
| 298 |
+
else:
|
| 299 |
+
real_losses[k] = v
|
| 300 |
+
|
| 301 |
+
loss, log_vars = self._parse_losses(real_losses)
|
| 302 |
+
|
| 303 |
+
outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
|
| 304 |
+
|
| 305 |
+
return outputs
|
| 306 |
+
|
| 307 |
+
def val_step(self, data_batch, **kwargs):
    """Run one validation iteration.

    The model is called with the entries of ``data_batch`` and ``kwargs``
    merged into a single set of keyword arguments, and whatever it returns
    is passed back unchanged.

    Args:
        data_batch (dict): Keyword arguments for the model call.
        **kwargs: Extra keyword arguments forwarded to the model call.

    Returns:
        The value returned by the model call.
    """
    return self(**data_batch, **kwargs)
|
| 316 |
+
|
| 317 |
+
@staticmethod
def _parse_losses(losses):
    """Parse the raw outputs (losses) of the network.

    Every entry is reduced to a scalar (mean of a tensor, or sum of the
    means of a list of tensors). The total loss is the sum of the reduced
    entries whose key contains ``"loss"``.

    Args:
        losses (dict): Raw output of the network, which usually contain
            losses and other necessary information. Each value must be a
            tensor or a list of tensors.

    Returns:
        tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
            which may be a weighted sum of all losses, log_vars contains
            all the variables to be sent to the logger (as Python numbers,
            including the total under the key ``"loss"``).

    Raises:
        TypeError: If a value is neither a tensor nor a list.
    """
    # Local import; it must come after the docstring, otherwise the string
    # above is just an expression and the function has no ``__doc__``.
    import torch.distributed as dist

    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(f"{loss_name} is not a tensor or list of tensors")

    loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)

    log_vars["loss"] = loss
    # The process-group state cannot change inside the loop, so query it once.
    reduce_across_ranks = dist.is_available() and dist.is_initialized()
    for loss_name, loss_value in log_vars.items():
        # reduce loss when distributed training
        if reduce_across_ranks:
            # Work on a detached copy so the returned ``loss`` tensor is not
            # modified by the in-place division / all_reduce.
            loss_value = loss_value.data.clone()
            dist.all_reduce(loss_value.div_(dist.get_world_size()))
        log_vars[loss_name] = loss_value.item()

    return loss, log_vars
|
torch_hub/facebookresearch_dinov2_main/dinov2/hub/depth/ops.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import warnings
|
| 7 |
+
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
    """Resize ``input`` with ``torch.nn.functional.interpolate``.

    When ``warning`` is set, ``size`` is given and ``align_corners`` is
    truthy, a warning is emitted if the input is upsampled along height or
    width and neither ``(output - 1)`` is a multiple of ``(input - 1)``
    along height nor along width. The check unpacks ``input.shape[2:]`` and
    ``size`` into exactly two values each.

    Args:
        input: Tensor forwarded to ``F.interpolate``.
        size: Target spatial size, forwarded to ``F.interpolate``.
        scale_factor: Forwarded to ``F.interpolate``.
        mode: Interpolation mode, forwarded to ``F.interpolate``.
        align_corners: Forwarded to ``F.interpolate``.
        warning: Whether to run the alignment check described above.

    Returns:
        The result of ``F.interpolate``.
    """
    if warning and size is not None and align_corners:
        input_h, input_w = tuple(int(x) for x in input.shape[2:])
        output_h, output_w = tuple(int(x) for x in size)
        # Upsampling along either axis: each output extent is compared with
        # the input extent of the *same* axis.
        if output_h > input_h or output_w > input_w:
            if (
                (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
                and (output_h - 1) % (input_h - 1)
                and (output_w - 1) % (input_w - 1)
            ):
                warnings.warn(
                    f"When align_corners={align_corners}, "
                    "the output would more aligned if "
                    f"input size {(input_h, input_w)} is `x+1` and "
                    f"out size {(output_h, output_w)} is `nx+1`"
                )
    return F.interpolate(input, size, scale_factor, mode, align_corners)
|
torch_hub/facebookresearch_dinov2_main/dinov2/hub/utils.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import itertools
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
    """Compose a DINOv2 model name.

    The architecture name is stripped of underscores and cut to its first
    four characters, followed by the patch size. A ``_reg<N>`` suffix is
    appended only when ``num_register_tokens`` is non-zero.
    """
    pieces = ["dinov2_", arch_name.replace("_", "")[:4], f"{patch_size}"]
    if num_register_tokens:
        pieces.append(f"_reg{num_register_tokens}")
    return "".join(pieces)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CenterPadding(nn.Module):
    """Pad every dimension after the first two up to a multiple of ``multiple``.

    The padding of each dimension is split between both sides; when the
    amount is odd, the extra element goes to the trailing side.
    """

    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        """Return ``(leading, trailing)`` padding that rounds ``size`` up."""
        target = math.ceil(size / self.multiple) * self.multiple
        leading, odd_one = divmod(target - size, 2)
        return leading, leading + odd_one

    @torch.inference_mode()
    def forward(self, x):
        # F.pad expects the pairs ordered from the last dimension backwards,
        # hence the reversed walk over the shape (dims 0 and 1 are skipped).
        pads = []
        for extent in x.shape[:1:-1]:
            pads.extend(self._get_pad(extent))
        return F.pad(x, pads)
|
torch_hub/facebookresearch_dinov2_main/dinov2/hub/xray_dino/__pycache__/backbones.cpython-310.pyc
ADDED
|
Binary file (847 Bytes). View file
|
|
|
torch_hub/facebookresearch_dinov2_main/dinov2/hub/xray_dino/backbones.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the licence
|
| 4 |
+
# found in the LICENSE_XRAY_DINO_MODEL file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import Union
|
| 7 |
+
|
| 8 |
+
from ..backbones import Weights, _make_dinov2_model
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def xray_dino_vitl16(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.XRAY_DINO, **kwargs):
    """
    XRay-DINO ViT-L/16 model (optionally) pretrained on the XRay-DINO dataset.

    All arguments are keyword-only. ``pretrained``, ``weights`` and any extra
    ``kwargs`` are forwarded to ``_make_dinov2_model`` together with the fixed
    architecture settings below.
    """
    # Settings that define this particular entry point.
    architecture = dict(
        arch_name="vit_large",
        patch_size=16,
        img_size=512,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        block_chunks=4,
    )
    return _make_dinov2_model(
        **architecture,
        pretrained=pretrained,
        weights=weights,
        hash="ad31c2b0",
        check_hash=True,
        **kwargs,
    )
|
torch_hub/facebookresearch_dinov2_main/dinov2/layers/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .dino_head import DINOHead
|
| 7 |
+
from .layer_scale import LayerScale
|
| 8 |
+
from .mlp import Mlp
|
| 9 |
+
from .patch_embed import PatchEmbed
|
| 10 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused, SwiGLUFFNAligned
|
| 11 |
+
from .block import NestedTensorBlock, CausalAttentionBlock
|
| 12 |
+
from .attention import Attention, MemEffAttention
|
torch_hub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/dino_head.cpython-310.pyc
ADDED
|
Binary file (1.99 kB). View file
|
|
|
torch_hub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc
ADDED
|
Binary file (1.35 kB). View file
|
|
|
torch_hub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc
ADDED
|
Binary file (3.25 kB). View file
|
|
|
torch_hub/facebookresearch_dinov2_main/dinov2/layers/attention.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch import nn, Tensor
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger("dinov2")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# xFormers is opted out of by defining the XFORMERS_DISABLED environment
# variable (any value, even an empty one).
XFORMERS_ENABLED = "XFORMERS_DISABLED" not in os.environ
try:
    if not XFORMERS_ENABLED:
        warnings.warn("xFormers is disabled (Attention)")
        # Deliberately reuse the "import failed" path below.
        raise ImportError
    from xformers.ops import memory_efficient_attention, unbind
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Attention)")
else:
    XFORMERS_AVAILABLE = True
    warnings.warn("xFormers is available (Attention)")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Attention(nn.Module):
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
dim: int,
|
| 40 |
+
num_heads: int = 8,
|
| 41 |
+
qkv_bias: bool = False,
|
| 42 |
+
proj_bias: bool = True,
|
| 43 |
+
attn_drop: float = 0.0,
|
| 44 |
+
proj_drop: float = 0.0,
|
| 45 |
+
) -> None:
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.dim = dim
|
| 48 |
+
self.num_heads = num_heads
|
| 49 |
+
head_dim = dim // num_heads
|
| 50 |
+
self.scale = head_dim**-0.5
|
| 51 |
+
|
| 52 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 53 |
+
self.attn_drop = attn_drop
|
| 54 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 55 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 56 |
+
|
| 57 |
+
def init_weights(
|
| 58 |
+
self, init_attn_std: float | None = None, init_proj_std: float | None = None, factor: float = 1.0
|
| 59 |
+
) -> None:
|
| 60 |
+
init_attn_std = init_attn_std or (self.dim**-0.5)
|
| 61 |
+
init_proj_std = init_proj_std or init_attn_std * factor
|
| 62 |
+
nn.init.normal_(self.qkv.weight, std=init_attn_std)
|
| 63 |
+
nn.init.normal_(self.proj.weight, std=init_proj_std)
|
| 64 |
+
if self.qkv.bias is not None:
|
| 65 |
+
nn.init.zeros_(self.qkv.bias)
|
| 66 |
+
if self.proj.bias is not None:
|
| 67 |
+
nn.init.zeros_(self.proj.bias)
|
| 68 |
+
|
| 69 |
+
def forward(self, x: Tensor, is_causal: bool = False) -> Tensor:
|
| 70 |
+
B, N, C = x.shape
|
| 71 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 72 |
+
q, k, v = torch.unbind(qkv, 2)
|
| 73 |
+
q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
|
| 74 |
+
x = nn.functional.scaled_dot_product_attention(
|
| 75 |
+
q, k, v, attn_mask=None, dropout_p=self.attn_drop if self.training else 0, is_causal=is_causal
|
| 76 |
+
)
|
| 77 |
+
x = x.transpose(1, 2).contiguous().view(B, N, C)
|
| 78 |
+
x = self.proj_drop(self.proj(x))
|
| 79 |
+
return x
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class MemEffAttention(Attention):
    """Attention variant that uses xFormers' memory-efficient kernel when available.

    Without xFormers it defers to ``Attention.forward`` and rejects any
    ``attn_bias``.
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if XFORMERS_AVAILABLE:
            batch, tokens, channels = x.shape
            packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
            query, key, value = unbind(packed, 2)
            attended = memory_efficient_attention(query, key, value, attn_bias=attn_bias)
            merged = attended.reshape([batch, tokens, channels])
            return self.proj_drop(self.proj(merged))

        # Fallback path: the plain implementation has no attn_bias support.
        if attn_bias is not None:
            raise AssertionError("xFormers is required for using nested tensors")
        return super().forward(x)
|
torch_hub/facebookresearch_dinov2_main/dinov2/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
from typing import Callable, Optional, Tuple, Union
|
| 11 |
+
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def make_2tuple(x):
    """Return ``x`` as a pair.

    A tuple is returned unchanged after asserting it has two elements; an
    int is duplicated into ``(x, x)``. Anything else fails an assertion.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)

    assert len(x) == 2
    return x
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: When False, ``forward`` reshapes its output to
            (B, H', W', D), where H' and W' are the spatial dimensions of the
            projected feature map, instead of returning (B, N, D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: kernel and stride are both the patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        # A missing norm layer is represented by nn.Identity, never by None.
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Estimate the cost of the projection, plus the norm when one is configured."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # ``self.norm`` is nn.Identity when no norm layer was given, so an
        # ``is not None`` test would always count the norm term.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
|
torch_hub/facebookresearch_dinov2_main/dinov2/logging/__init__.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import functools
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
import dinov2.distributed as distributed
|
| 13 |
+
from .helpers import MetricLogger, SmoothedValue
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# So that calling _configure_logger multiple times won't add many handlers
|
| 17 |
+
@functools.lru_cache()
def _configure_logger(
    name: Optional[str] = None,
    *,
    level: int = logging.DEBUG,
    output: Optional[str] = None,
):
    """
    Configure a logger.

    Adapted from Detectron2.

    The main process gets a stdout handler. When ``output`` is given, every
    process also gets a handler that appends to a log file; processes other
    than the main one add a ``.rank<N>`` suffix to the file name.

    Args:
        name: The name of the logger to configure.
        level: The logging level to use.
        output: A file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/logs/log.txt`.

    Returns:
        The configured logger.
    """

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False

    # Loosely match Google glog format:
    #   [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg
    # but use a shorter timestamp, the process id instead of the thread id,
    # and include the logger name:
    #   [IWEF]yyyymmdd hh:mm:ss processid logger file:line] msg
    fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] "
    fmt_message = "%(message)s"
    fmt = fmt_prefix + fmt_message
    datefmt = "%Y%m%d %H:%M:%S"
    formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)

    # stdout logging for main worker only
    if distributed.is_main_process():
        handler = logging.StreamHandler(stream=sys.stdout)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    # file logging for all workers
    if output:
        if os.path.splitext(output)[-1] in (".txt", ".log"):
            filename = output
        else:
            filename = os.path.join(output, "logs", "log.txt")

        if not distributed.is_main_process():
            global_rank = distributed.get_global_rank()
            filename = filename + ".rank{}".format(global_rank)

        # A bare file name such as "train.log" has an empty dirname, and
        # os.makedirs("") raises FileNotFoundError; nothing needs creating
        # in that case.
        log_dir = os.path.dirname(filename)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        # The stream is intentionally left open: the handler writes to it for
        # as long as the logger is in use.
        handler = logging.StreamHandler(open(filename, "a"))
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def setup_logging(
    output: Optional[str] = None,
    *,
    name: Optional[str] = None,
    level: int = logging.DEBUG,
    capture_warnings: bool = True,
) -> None:
    """
    Setup logging.

    Args:
        output: A file name or a directory to save log files. If None, log
            files will not be saved. If output ends with ".txt" or ".log", it
            is assumed to be a file name.
            Otherwise, logs will be saved to `output/logs/log.txt`.
        name: The name of the logger to configure, by default the root logger.
        level: The logging level to use.
        capture_warnings: Whether warnings should be captured as logs.
    """
    logging.captureWarnings(capture_warnings)
    # _configure_logger is memoised with lru_cache, whose key depends on how
    # the arguments are passed; keep this exact positional/keyword form.
    _configure_logger(name, level=level, output=output)
|
torch_hub/facebookresearch_dinov2_main/dinov2/logging/helpers.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import defaultdict, deque
|
| 7 |
+
import datetime
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
import dinov2.distributed as distributed
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger("dinov2")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and log them while iterating.

    Args:
        delimiter: String placed between the fields of a log line.
        output_file: Optional path; when set, the main process appends one
            JSON line per logging step to it.
    """

    def __init__(self, delimiter="\t", output_file=None):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        self.output_file = output_file

    def update(self, **kwargs):
        """Feed one value per keyword into the meter of the same name."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # __getattr__ only runs after normal lookup has failed. ``meters`` is
        # read through __dict__ so that an instance on which it has not been
        # set yet raises AttributeError instead of recursing forever.
        meters = self.__dict__.get("meters", {})
        if attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))

    def __str__(self):
        loss_str = ["{}: {}".format(name, str(meter)) for name, meter in self.meters.items()]
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def dump_in_output_file(self, iteration, iter_time, data_time):
        """Append the timings and every meter's median as one JSON line.

        Does nothing when no output file is configured or when this is not
        the main process.
        """
        if self.output_file is None or not distributed.is_main_process():
            return
        dict_to_dump = dict(
            iteration=iteration,
            iter_time=iter_time,
            data_time=data_time,
        )
        dict_to_dump.update({k: v.median for k, v in self.meters.items()})
        with open(self.output_file, "a") as f:
            f.write(json.dumps(dict_to_dump) + "\n")

    def log_every(self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0):
        """Yield the items of ``iterable`` while periodically logging progress.

        Args:
            iterable: Source of items.
            print_freq: A line is logged whenever the iteration counter is a
                multiple of this value, and on the last iteration.
            header: Optional prefix of each log line.
            n_iterations: Iteration count at which to stop; defaults to
                ``len(iterable)``.
            start_iteration: Initial value of the iteration counter.
        """
        i = start_iteration
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.6f}")
        data_time = SmoothedValue(fmt="{avg:.6f}")

        if n_iterations is None:
            n_iterations = len(iterable)

        space_fmt = ":" + str(len(str(n_iterations))) + "d"

        log_list = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        # The memory field is only part of the message on CUDA machines.
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            log_list += ["max mem: {memory:.0f}"]

        log_msg = self.delimiter.join(log_list)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == n_iterations - 1:
                self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg)
                eta_seconds = iter_time.global_avg * (n_iterations - i)
                fields = dict(
                    eta=str(datetime.timedelta(seconds=int(eta_seconds))),
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                if use_cuda:
                    fields["memory"] = torch.cuda.max_memory_allocated() / MB
                logger.info(log_msg.format(i, n_iterations, **fields))
            i += 1
            end = time.time()
            if i >= n_iterations:
                break
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # Guard the per-iteration average against an empty run.
        time_per_it = total_time / n_iterations if n_iterations else 0.0
        logger.info("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, time_per_it))
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class SmoothedValue:
    """Keep a sliding window of recent values next to running totals.

    Window statistics (median, avg, max, value) look only at the last
    ``window_size`` values; ``global_avg`` is ``total / count`` over
    everything passed to ``update``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, num=1):
        """Record ``value`` in the window and add it ``num`` times to the totals."""
        self.deque.append(value)
        self.count += num
        self.total += value * num

    def synchronize_between_processes(self):
        """
        Distributed synchronization of the metric
        Warning: does not synchronize the deque!
        """
        if distributed.is_enabled():
            packed = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
            torch.distributed.barrier()
            torch.distributed.all_reduce(packed)
            synced_count, synced_total = packed.tolist()
            self.count = int(synced_count)
            self.total = synced_total

    @property
    def median(self):
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        # Every statistic is computed up front, whether or not the format
        # string refers to it.
        stats = {
            "median": self.median,
            "avg": self.avg,
            "global_avg": self.global_avg,
            "max": self.max,
            "value": self.value,
        }
        return self.fmt.format(**stats)
|
torch_hub/facebookresearch_dinov2_main/dinov2/loss/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .dino_clstoken_loss import DINOLoss
|
| 7 |
+
from .ibot_patch_loss import iBOTPatchLoss
|
| 8 |
+
from .koleo_loss import KoLeoLoss
|
torch_hub/facebookresearch_dinov2_main/dinov2/loss/ibot_patch_loss.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger("dinov2")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
try:
    from xformers.ops import cross_entropy

    def lossfunc(t, s, temp):
        """Negated xFormers cross-entropy of ``s`` against ``t`` at ``temp``.

        Both inputs are cast to float. 2-D inputs get a temporary leading
        dimension; 3-D inputs are passed through as they are. Any other
        rank yields ``None``.
        """
        student, teacher = s.float(), t.float()
        if student.ndim == 2:
            batched = cross_entropy(student.unsqueeze(0), teacher.unsqueeze(0), temp, bw_inplace=True)
            return -batched.squeeze(0)
        if student.ndim == 3:
            return -cross_entropy(student, teacher, temp, bw_inplace=True)

except ImportError:

    def lossfunc(t, s, temp):
        """Sum over the last dim of ``t * log_softmax(s / temp)``."""
        log_probs = F.log_softmax(s / temp, dim=-1)
        return torch.sum(t * log_probs, dim=-1)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class iBOTPatchLoss(nn.Module):
    """Patch-token distillation loss with a running center for the teacher.

    The module keeps a `center` buffer of shape (1, 1, patch_out_dim) that is
    updated as an exponential moving average of teacher outputs. The update
    is split in two steps: `reduce_center_update` computes (and, when
    torch.distributed is initialized, asynchronously all-reduces) the batch
    statistic, and `apply_center_update` folds it into `center` the next time
    `softmax_center_teacher` is called.

    Args:
        patch_out_dim: size of the last dimension of the patch token outputs.
        student_temp: temperature dividing the student outputs.
        center_momentum: EMA factor applied to the previous center.
    """

    def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, 1, patch_out_dim))
        # Bookkeeping for the deferred center update.
        self.updated = True  # False while a computed batch center is waiting to be applied
        self.reduce_handle = None  # async all_reduce work handle, if any
        self.len_teacher_patch_tokens = None
        self.async_batch_center = None

    @torch.no_grad()
    def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp):
        """Center and sharpen teacher outputs, returning softmax probabilities.

        Any pending center update is applied first.
        """
        self.apply_center_update()
        # WARNING: self.center is a float32 buffer, so the subtraction below
        # promotes lower-precision teacher outputs to float32.
        return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1)

    @torch.no_grad()
    def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3):
        """Turn teacher outputs into assignments with Sinkhorn-Knopp iterations.

        Args:
            teacher_output: 2-D tensor; it is transposed internally so that
                rows are prototypes and columns are samples.
            teacher_temp: temperature dividing the teacher outputs.
            n_masked_patches_tensor: tensor holding the local number of
                samples. When torch.distributed is initialized it is
                all-reduced IN PLACE to obtain the global count.
            n_iterations: number of row/column normalization rounds.

        Returns:
            Tensor with the same layout as `teacher_output` whose rows sum to 1.
        """
        teacher_output = teacher_output.float()
        Q = torch.exp(teacher_output / teacher_temp).t()  # Q is K-by-B for consistency with notations from our paper
        B = n_masked_patches_tensor  # number of samples to assign
        # Guarded like every other collective in this class: without an
        # initialized process group the local count already is the total.
        if dist.is_initialized():
            dist.all_reduce(B)
        K = Q.shape[0]  # how many prototypes

        # make the matrix sum to 1
        sum_Q = torch.sum(Q)
        if dist.is_initialized():
            dist.all_reduce(sum_Q)
        Q /= sum_Q

        for _ in range(n_iterations):
            # normalize each row: total weight per prototype must be 1/K
            sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
            if dist.is_initialized():
                dist.all_reduce(sum_of_rows)
            Q /= sum_of_rows
            Q /= K

            # normalize each column: total weight per sample must be 1/B
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= B

        Q *= B  # the columns must sum to 1 so that Q is an assignment
        return Q.t()

    def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        student_patch_tokens: (B, N, D) tensor
        teacher_patch_tokens: (B, N, D) tensor
        student_masks_flat: (B, N) tensor

        Per-token losses are averaged over the masked positions of each row
        (the count is clamped to at least 1), then averaged over rows.
        """
        t = teacher_patch_tokens
        s = student_patch_tokens
        loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
        loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0)
        return -loss.mean()

    def forward_masked(
        self,
        student_patch_tokens_masked,
        teacher_patch_tokens_masked,
        student_masks_flat,
        n_masked_patches=None,
        masks_weight=None,
    ):
        """Loss variant for inputs that already contain only masked tokens.

        Args:
            student_patch_tokens_masked: student outputs for the masked tokens.
            teacher_patch_tokens_masked: teacher targets for the same tokens.
            student_masks_flat: mask tensor used to derive the default
                weights and, through its first dimension, the normalizer.
            n_masked_patches: if given, only the first `n_masked_patches`
                per-token losses are kept.
            masks_weight: per-token weights; defaults to 1 / (number of
                masked tokens in the token's row, clamped to at least 1).

        Returns:
            Negated weighted sum of the per-token losses divided by
            `student_masks_flat.shape[0]`.
        """
        t = teacher_patch_tokens_masked
        s = student_patch_tokens_masked
        loss = lossfunc(t, s, self.student_temp)
        if masks_weight is None:
            masks_weight = (
                (1 / student_masks_flat.sum(-1).clamp(min=1.0))
                .unsqueeze(-1)
                .expand_as(student_masks_flat)[student_masks_flat]
            )
        if n_masked_patches is not None:
            loss = loss[:n_masked_patches]
        loss = loss * masks_weight
        return -loss.sum() / student_masks_flat.shape[0]

    @torch.no_grad()
    def update_center(self, teacher_patch_tokens):
        """Start a center update from the given teacher outputs."""
        self.reduce_center_update(teacher_patch_tokens)

    @torch.no_grad()
    def reduce_center_update(self, teacher_patch_tokens):
        """Compute the batch statistic for the center and start its reduction.

        The statistic is the mean over dim 1 summed over dim 0. With
        torch.distributed initialized, an asynchronous all_reduce is launched
        and its handle is stored for `apply_center_update` to wait on.
        """
        self.updated = False
        self.len_teacher_patch_tokens = len(teacher_patch_tokens)
        self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True)
        if dist.is_initialized():
            self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)

    @torch.no_grad()
    def apply_center_update(self):
        """Fold the pending batch statistic into `center` (no-op if none is pending)."""
        if self.updated is False:
            world_size = dist.get_world_size() if dist.is_initialized() else 1

            if self.reduce_handle is not None:
                self.reduce_handle.wait()
            # Turn the (reduced) sum into a mean over all samples of all ranks.
            _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size)

            self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)

            self.updated = True
|
torch_hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
from functools import partial
|
| 11 |
+
import math
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Sequence, Tuple, Union, Callable
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.utils.checkpoint
|
| 19 |
+
from torch.nn.init import trunc_normal_
|
| 20 |
+
|
| 21 |
+
from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger("dinov2")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Call `fn(module=..., name=...)` on the descendants of `module`.

    Children are visited recursively with dotted names built from `name`.
    The module passed in is itself visited only when `include_root` is true:
    before its children when `depth_first` is false, after them otherwise.
    Descendants are always visited. Returns `module`.
    """
    visit_before_children = include_root and not depth_first
    visit_after_children = include_root and depth_first

    if visit_before_children:
        fn(module=module, name=name)

    for short_name, child in module.named_children():
        qualified_name = ".".join((name, short_name)) if name else short_name
        named_apply(fn=fn, module=child, name=qualified_name, depth_first=depth_first, include_root=True)

    if visit_after_children:
        fn(module=module, name=name)
    return module
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BlockChunk(nn.ModuleList):
    """A ModuleList that can be called: it applies its modules one after another."""

    def forward(self, x):
        """Feed `x` through every contained module in order and return the result."""
        out = x
        for layer in self:
            out = layer(out)
        return out
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class DinoVisionTransformer(nn.Module):
    """Vision Transformer backbone with optional register tokens and chunked blocks.

    Tokens are laid out as [cls, register tokens (if any), patch tokens].
    `forward_features` returns a dict with the normalized cls / register /
    patch tokens, the pre-norm tokens and the masks; `forward` returns either
    that dict (is_training=True) or `head(cls token)`.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        channel_adaptive=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
            channel_adaptive: (bool) stored as `self.bag_of_channels`; when True,
                `get_intermediate_layers` moves the channel dimension into the
                batch dimension and encodes every channel separately

        Raises:
            NotImplementedError: if `ffn_layer` is not one of the supported names.
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset
        self.bag_of_channels = channel_adaptive

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One positional embedding per patch plus one for the cls token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = np.linspace(0, drop_path_rate, depth).tolist()  # stochastic depth decay rule

        # Resolve the FFN name into a class (or factory) handed to every block.
        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            # Factory that ignores all constructor arguments and returns an identity module.
            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                # (each chunk is left-padded with i identity modules, so a block
                # keeps its global index inside its chunk)
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # Learned embedding substituted for masked patch tokens in prepare_tokens_with_masks.
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initialize positional embedding, cls/register tokens and all Linear layers."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Return positional embeddings matching the token grid of a w-by-h input.

        `x` is the token tensor including the cls token (its second dimension
        minus one is the patch count). When the patch count equals the stored
        one and `w == h`, the stored embedding is returned unchanged;
        otherwise the patch part is bicubically resized from the stored M x M
        grid to (w // patch_size, h // patch_size) and the result is cast to
        the dtype of `x`.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        # The resized grid must match the expected token grid exactly.
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Embed a 4-D image batch into the token sequence fed to the blocks.

        Steps: patch embedding, optional replacement of masked patch tokens by
        `mask_token`, cls token prepended, positional encoding added, register
        tokens (if any) inserted right after the cls token.

        `masks`, when given, is used as the `torch.where` condition after
        `unsqueeze(-1)` — presumably a boolean (batch, num_patches) tensor;
        confirm against callers.
        """
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            # Registers go between the cls token and the patch tokens; they get no positional encoding.
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """Run the backbone on a list of image batches; returns one output dict per entry.

        The whole list of token tensors is handed to every block at once —
        NOTE(review): this relies on the block class accepting a list input;
        confirm for custom `block_fn`.
        """
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Run the backbone and return a dict of normalized and pre-norm tokens.

        A list input is dispatched to `forward_features_list`.
        """
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        """Collect block outputs when `self.blocks` is a flat list of blocks."""
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        """Collect block outputs when `self.blocks` is a list of BlockChunk modules."""
        x = self.prepare_tokens_with_masks(x)
        # The last chunk is padded with identities up to its global offset, so
        # its length is used as the total number of blocks.
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch tokens (and optionally cls tokens) of selected blocks.

        Args:
            x: 4-D image batch.
            n: number of last blocks to take, or explicit block indices.
            reshape: reshape each patch-token tensor into a
                (batch, dim, w // patch_size, h // patch_size) feature map.
            return_class_token: return (patch_tokens, cls_token) pairs.
            norm: apply the final norm layer to every selected output.

        Returns:
            Tuple with one entry per selected block. With `bag_of_channels`
            set, each entry is always a (patch_tokens, cls_tokens) pair whose
            patch tokens are reshaped to (B, C, N, D) and cls tokens to (B, -1).
        """

        if self.bag_of_channels:
            B, C, H, W = x.shape
            x = x.reshape(B * C, 1, H, W)  # passing channels to batch dimension to get encodings for each channel

        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # Drop the cls token and the register tokens, keeping patch tokens only.
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            # NOTE(review): with bag_of_channels this rebinds B to B * C, and the
            # (B, C, N, D) reshape below then no longer matches — the two
            # options look incompatible; confirm before combining them.
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]

        if self.bag_of_channels:
            output = tuple(zip(outputs, class_tokens))
            output = list(
                zip(*output)
            )  # unzip the tuple: (list of patch_tokens per block, list of class tokens per block)
            patch_tokens_per_block = output[0]  # [BLOCK1, BLOCK2, ...] where BLOCK1.shape: B*C, N, D
            cls_tokens_per_block = output[1]  # [BLOCK1, BLOCK2, ...] where BLOCK1.shape: B*C, D
            patch_tokens_per_block = [
                patch_tokens.reshape(B, C, patch_tokens.shape[-2], patch_tokens.shape[-1])
                for patch_tokens in patch_tokens_per_block
            ]  # [BLOCK1, BLOCK2, ...] where BLOCK1.shape: B, C, N, D
            cls_tokens_per_block = [cls_tokens.reshape(B, -1) for cls_tokens in cls_tokens_per_block]
            output = tuple(zip(patch_tokens_per_block, cls_tokens_per_block))
            return output

        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Return the full feature dict when `is_training`, else `head` applied to the cls token."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """Initialize a Linear layer the way the original timm ViT does (for reproducibility).

    Weights are drawn from a truncated normal with std 0.02 and the bias,
    if present, is zeroed. Modules that are not `nn.Linear` are left alone.
    `name` is accepted for compatibility with `named_apply` and is unused.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def vit_small(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs):
    """Build a DinoVisionTransformer with embed-dim 384, 12 blocks and 6 heads.

    Blocks use `MemEffAttention`; extra keyword arguments are forwarded to
    the `DinoVisionTransformer` constructor.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        in_chans=in_chans,
        num_register_tokens=num_register_tokens,
        channel_adaptive=channel_adaptive,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=attention_block,
        **kwargs,
    )
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def vit_base(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs):
    """Build a DinoVisionTransformer with embed-dim 768, 12 blocks and 12 heads.

    Blocks use `MemEffAttention`; extra keyword arguments are forwarded to
    the `DinoVisionTransformer` constructor.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        in_chans=in_chans,
        num_register_tokens=num_register_tokens,
        channel_adaptive=channel_adaptive,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=attention_block,
        **kwargs,
    )
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def vit_large(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs):
    """Build a DinoVisionTransformer with embed-dim 1024, 24 blocks and 16 heads.

    Blocks use `MemEffAttention`; extra keyword arguments are forwarded to
    the `DinoVisionTransformer` constructor.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        in_chans=in_chans,
        num_register_tokens=num_register_tokens,
        channel_adaptive=channel_adaptive,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=attention_block,
        **kwargs,
    )
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs):
    """
    Build a DinoVisionTransformer close to ViT-giant: embed-dim 1536, 40 blocks
    and 24 heads => embed-dim per head 64.

    Blocks use `MemEffAttention`; extra keyword arguments are forwarded to
    the `DinoVisionTransformer` constructor.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        in_chans=in_chans,
        num_register_tokens=num_register_tokens,
        channel_adaptive=channel_adaptive,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=attention_block,
        **kwargs,
    )
|
torch_hub/facebookresearch_dinov2_main/dinov2/run/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
torch_hub/facebookresearch_dinov2_main/dinov2/run/eval/cell_dino/knn.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the CC-by-NC licence,
|
| 4 |
+
# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
from dinov2.eval.cell_dino.knn import get_args_parser as get_knn_args_parser
|
| 11 |
+
from dinov2.logging import setup_logging
|
| 12 |
+
from dinov2.run.submit import get_args_parser, submit_jobs
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger("dinov2")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Evaluator:
    """Callable job object that runs the Cell-DINO k-NN evaluation under submitit."""

    def __init__(self, args):
        self.args = args

    def _setup_args(self):
        """Substitute the job id for "%j" in the output directory and log the job setup."""
        import submitit

        env = submitit.JobEnvironment()
        resolved_dir = self.args.output_dir.replace("%j", str(env.job_id))
        self.args.output_dir = resolved_dir
        logger.info(f"Process group: {env.num_tasks} tasks, rank: {env.global_rank}")
        logger.info(f"Args: {self.args}")

    def __call__(self):
        """Job entry point: finalize the arguments, then run the evaluation."""
        from dinov2.eval.cell_dino.knn import main as knn_main

        self._setup_args()
        knn_main(self.args)

    def checkpoint(self):
        """Return a delayed submission of a fresh evaluator built from the same arguments."""
        import submitit

        logger.info(f"Requeuing {self.args}")
        return submitit.helpers.DelayedSubmission(type(self)(self.args))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main():
    """Parse the command line and submit the k-NN evaluation job(s); returns 0."""
    args_parser = get_args_parser(
        description="Submitit launcher for k-NN Cell-DINO and Channel-Adaptive DINO evaluation",
        parents=[get_knn_args_parser(add_help=False)],
    )
    args = args_parser.parse_args()

    setup_logging()

    assert os.path.exists(args.config_file), "Configuration file does not exist!"
    submit_jobs(Evaluator, args, name="dinov2:knn Cell-DINO")
    return 0
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Script entry point: run main() and use its return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|
torch_hub/facebookresearch_dinov2_main/dinov2/run/eval/cell_dino/linear.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the CC-by-NC licence,
|
| 4 |
+
# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
from dinov2.eval.cell_dino.linear import get_args_parser as get_linear_args_parser
|
| 11 |
+
from dinov2.logging import setup_logging
|
| 12 |
+
from dinov2.run.submit import get_args_parser, submit_jobs
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger("dinov2")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Evaluator:
    """Callable job object that runs the Cell-DINO linear evaluation under submitit."""

    def __init__(self, args):
        self.args = args

    def _setup_args(self):
        """Substitute the job id for "%j" in the output directory and log the job setup."""
        import submitit

        env = submitit.JobEnvironment()
        resolved_dir = self.args.output_dir.replace("%j", str(env.job_id))
        self.args.output_dir = resolved_dir
        logger.info(f"Process group: {env.num_tasks} tasks, rank: {env.global_rank}")
        logger.info(f"Args: {self.args}")

    def __call__(self):
        """Job entry point: finalize the arguments, then run the evaluation."""
        from dinov2.eval.cell_dino.linear import main as linear_main

        self._setup_args()
        linear_main(self.args)

    def checkpoint(self):
        """Return a delayed submission of a fresh evaluator built from the same arguments."""
        import submitit

        logger.info(f"Requeuing {self.args}")
        return submitit.helpers.DelayedSubmission(type(self)(self.args))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main():
    """Parse the command line and submit the linear evaluation job(s); returns 0."""
    args_parser = get_args_parser(
        description="Submitit launcher for DINOv2 linear Cell-DINO and Channel-Adaptive DINO evaluation",
        parents=[get_linear_args_parser(add_help=False)],
    )
    args = args_parser.parse_args()

    setup_logging()

    assert os.path.exists(args.config_file), "Configuration file does not exist!"
    submit_jobs(Evaluator, args, name="dinov2:linear Cell-DINO")
    return 0
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Script entry point: run main() and use its return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|
torch_hub/facebookresearch_dinov2_main/dinov2/run/submit.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import List, Optional
|
| 11 |
+
|
| 12 |
+
import submitit
|
| 13 |
+
|
| 14 |
+
from dinov2.utils.cluster import (
|
| 15 |
+
get_slurm_executor_parameters,
|
| 16 |
+
get_slurm_partition,
|
| 17 |
+
get_user_checkpoint_path,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger("dinov2")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
) -> argparse.ArgumentParser:
    """Build the argument parser holding the job-submission options.

    Args:
        description: text passed to ``argparse.ArgumentParser``.
        parents: parsers whose arguments are inherited; ``None`` means no parents.
        add_help: whether ``-h/--help`` is added to the parser.

    Returns:
        The parser with the GPU, node, timeout, partition, V100-32GB, comment
        and exclude options registered.
    """
    parent_parsers = parents or []
    default_partition = get_slurm_partition()
    parser = argparse.ArgumentParser(
        description=description,
        parents=parent_parsers,
        add_help=add_help,
    )

    # (option strings, add_argument keyword arguments), in registration order.
    option_table = (
        (
            ("--ngpus", "--gpus", "--gpus-per-node"),
            {"default": 8, "type": int, "help": "Number of GPUs to request on each node"},
        ),
        (
            ("--nodes", "--nnodes"),
            {"default": 1, "type": int, "help": "Number of nodes to request"},
        ),
        (
            ("--timeout",),
            {"default": 2800, "type": int, "help": "Duration of the job"},
        ),
        (
            ("--partition",),
            {"default": default_partition, "type": str, "help": "Partition where to submit"},
        ),
        (
            ("--use-volta32",),
            {"action": "store_true", "help": "Request V100-32GB GPUs"},
        ),
        (
            ("--comment",),
            {"default": "", "type": str, "help": "Comment to pass to scheduler, e.g. priority message"},
        ),
        (
            ("--exclude",),
            {"default": "", "type": str, "help": "Nodes to exclude"},
        ),
    )
    for option_strings, settings in option_table:
        parser.add_argument(*option_strings, **settings)
    return parser
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def get_shared_folder() -> Path:
    """Return the ``experiments`` directory under the user checkpoint path, creating it if missing.

    Raises:
        RuntimeError: if the user checkpoint path cannot be determined.
    """
    checkpoint_root = get_user_checkpoint_path()
    if checkpoint_root is None:
        raise RuntimeError("Path to user checkpoint cannot be determined")
    experiments_dir = checkpoint_root / "experiments"
    # Only the leaf directory is created; the checkpoint root itself must already exist.
    experiments_dir.mkdir(exist_ok=True)
    return experiments_dir
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def submit_jobs(task_class, args, name: str):
    """Configure a submitit executor from ``args`` and submit one ``task_class(args)`` job.

    Args:
        task_class: callable class instantiated with ``args`` to produce the submitted task.
        args: parsed arguments; reads ``output_dir``, ``use_volta32``, ``comment``,
            ``exclude``, ``nodes``, ``ngpus``, ``timeout`` and ``partition``.
            ``output_dir`` is filled in with a ``%j``-templated default when empty.
        name: job name given to the executor.
    """
    if not args.output_dir:
        args.output_dir = str(get_shared_folder() / "%j")

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)

    # Optional scheduler settings are only forwarded when the matching argument is set.
    optional_settings = (
        ("slurm_constraint", args.use_volta32, "volta32gb"),
        ("slurm_comment", args.comment, args.comment),
        ("slurm_exclude", args.exclude, args.exclude),
    )
    extra_params = {key: value for key, enabled, value in optional_settings if enabled}

    executor_params = get_slurm_executor_parameters(
        nodes=args.nodes,
        num_gpus_per_node=args.ngpus,
        timeout_min=args.timeout,  # max is 60 * 72
        slurm_signal_delay_s=120,
        slurm_partition=args.partition,
        **extra_params,
    )
    executor.update_parameters(name=name, **executor_params)

    job = executor.submit(task_class(args))

    logger.info(f"Submitted job_id: {job.job_id}")
    resolved_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id))
    logger.info(f"Logs and checkpoints will be saved at: {resolved_output_dir}")
|
torch_hub/facebookresearch_dinov2_main/dinov2/run/train/train.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
from dinov2.logging import setup_logging
|
| 11 |
+
from dinov2.train import get_args_parser as get_train_args_parser
|
| 12 |
+
from dinov2.run.submit import get_args_parser, submit_jobs
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger("dinov2")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Trainer(object):
    """Callable task that runs DINOv2 training with the stored arguments."""

    def __init__(self, args):
        self.args = args

    def __call__(self):
        # Imported lazily, only when the task is actually invoked.
        from dinov2.train import main as train_main

        self._setup_args()
        train_main(self.args)

    def checkpoint(self):
        """Log the requeue and return a delayed submission of a fresh task with the same arguments."""
        import submitit

        logger.info(f"Requeuing {self.args}")
        fresh_task = type(self)(self.args)
        return submitit.helpers.DelayedSubmission(fresh_task)

    def _setup_args(self):
        """Substitute the job id for ``%j`` in the output directory and log the job layout."""
        import submitit

        env = submitit.JobEnvironment()
        job_id = str(env.job_id)
        self.args.output_dir = self.args.output_dir.replace("%j", job_id)
        logger.info(f"Process group: {env.num_tasks} tasks, rank: {env.global_rank}")
        logger.info(f"Args: {self.args}")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main():
    """Parse the launcher and training arguments, then submit the training job.

    Returns:
        0 once the job has been submitted.
    """
    args_parser = get_args_parser(
        description="Submitit launcher for DINOv2 training",
        parents=[get_train_args_parser(add_help=False)],
    )
    args = args_parser.parse_args()

    setup_logging()

    assert os.path.exists(args.config_file), "Configuration file does not exist!"
    submit_jobs(Trainer, args, name="dinov2:train")
    return 0
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
|
| 59 |
+
sys.exit(main())
|
torch_hub/facebookresearch_dinov2_main/dinov2/thirdparty/CLIP/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2021 OpenAI
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
torch_hub/facebookresearch_dinov2_main/dinov2/thirdparty/CLIP/clip/simple_tokenizer.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gzip
|
| 2 |
+
import html
|
| 3 |
+
import os
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
|
| 6 |
+
import ftfy
|
| 7 |
+
import regex as re
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@lru_cache()
def default_bpe():
    """Return the absolute path of ``bpe_simple_vocab_16e6.txt.gz`` located next to this module."""
    module_path = os.path.abspath(__file__)
    module_dir = os.path.dirname(module_path)
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@lru_cache()
def bytes_to_unicode():
    """
    Return a dict mapping each of the 256 byte values to a distinct one-character string.

    Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character with the
    same code point. Every other byte is assigned, in increasing byte order, the next
    code point starting at 256, so no byte maps to one of the excluded characters.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    byte_values = list(printable)
    code_points = list(printable)
    shifted = 0
    for byte in range(256):
        if byte in printable:
            continue
        # Remaining bytes are remapped above the single-byte range.
        byte_values.append(byte)
        code_points.append(256 + shifted)
        shifted += 1
    return {byte: chr(code_point) for byte, code_point in zip(byte_values, code_points)}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Args:
        word: sequence of symbols (symbols being variable-length strings).

    Returns:
        A set of ``(left, right)`` tuples, one per adjacent position. The set is
        empty for words with fewer than two symbols, including the empty word.
    """
    # zip stops at the shorter operand, so empty and single-symbol words give no pairs.
    return set(zip(word, word[1:]))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def basic_clean(text):
    """Run ``ftfy.fix_text``, unescape HTML entities twice, and strip surrounding whitespace."""
    fixed = ftfy.fix_text(text)
    # Two passes, so an entity that is itself escaped (e.g. "&amp;lt;") is fully resolved.
    unescaped = html.unescape(html.unescape(fixed))
    return unescaped.strip()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def whitespace_clean(text):
    """Collapse every run of whitespace into a single space and strip both ends."""
    return re.sub(r"\s+", " ", text).strip()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer built from a gzip-compressed merges file."""

    def __init__(self, bpe_path: str = default_bpe()):
        """Load the merge rules from ``bpe_path`` and build the encoder/decoder tables.

        Args:
            bpe_path: path to a gzip-compressed UTF-8 text file with one merge per line.
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Close the merges file deterministically instead of leaving the handle open.
        with gzip.open(bpe_path) as f:
            merges = f.read().decode("utf-8").split("\n")
        # Drop the first line and keep a fixed-size prefix of the merges list.
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocabulary order: byte symbols, byte symbols with the end-of-word marker,
        # merged symbols, then the two special tokens.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Special tokens are pre-seeded so bpe() returns them unsplit.
        self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )

    def bpe(self, token):
        """Apply the learned merges to ``token`` and return its symbols joined by spaces.

        Results are memoized in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        # The last character carries the end-of-word marker.
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            # Merge the lowest-ranked (earliest learned) pair present in the word.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # No further occurrence of `first`: copy the remainder unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lowercase ``text``, then return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Map the token's UTF-8 bytes to their stand-in unicode characters before merging.
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        """Convert token ids back to text; each end-of-word marker becomes a space."""
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("</w>", " ")
        return text
|
torch_hub/facebookresearch_dinov2_main/dinov2/utils/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
torch_hub/facebookresearch_dinov2_main/dinov2/utils/checkpoint.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the CC-by-NC licence,
|
| 4 |
+
# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
import dinov2.distributed as dist
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class PeriodicCheckpointerWithCleanup(PeriodicCheckpointer):
    """Periodic checkpointer that forwards a few calls to its wrapped checkpointer
    and skips ``step`` when that checkpointer does not write to disk."""

    @property
    def does_write(self) -> bool:
        """See https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py#L114"""
        # bool() so the property honours its annotation instead of leaking the
        # raw `and` operand (e.g. an empty save_dir string).
        return bool(self.checkpointer.save_dir and self.checkpointer.save_to_disk)

    def save_best(self, **kwargs: Any) -> None:
        """Same argument as `Checkpointer.save`, to save a model named like `model_best.pth`"""
        self.checkpointer.save(f"{self.file_prefix}_best", **kwargs)

    def has_checkpoint(self) -> bool:
        """Forward to the wrapped checkpointer's ``has_checkpoint``."""
        return self.checkpointer.has_checkpoint()

    def get_checkpoint_file(self) -> str:  # returns "" if the file does not exist
        """Forward to the wrapped checkpointer's ``get_checkpoint_file``."""
        return self.checkpointer.get_checkpoint_file()

    def load(self, path: str, checkpointables=None) -> dict[str, Any]:
        """Forward to the wrapped checkpointer's ``load``."""
        return self.checkpointer.load(path=path, checkpointables=checkpointables)

    def step(self, iteration: int, **kwargs: Any) -> None:
        """Run the periodic step only when this object writes checkpoints."""
        if not self.does_write:  # step also removes files, so should be deactivated when object does not write
            return
        super().step(iteration=iteration, **kwargs)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def resume_or_load(checkpointer: Checkpointer, path: str, *, resume: bool = True) -> dict[str, Any]:
    """
    Load a checkpoint, preferring the checkpointer's latest one when resuming.

    When `resume` is True and the checkpointer reports an existing checkpoint, that
    checkpoint file is loaded; in every other case the given `path` is loaded.
    Similar to Checkpointer.resume_or_load in fvcore
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py#L208
    but always reload checkpointables, in case we want to resume the training in a new job.
    """
    use_last_checkpoint = resume and checkpointer.has_checkpoint()
    target = checkpointer.get_checkpoint_file() if use_last_checkpoint else path
    return checkpointer.load(target)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def build_periodic_checkpointer(
    model: nn.Module,
    save_dir="",
    *,
    period: int,
    max_iter=None,
    max_to_keep=None,
    **checkpointables: Any,
) -> PeriodicCheckpointerWithCleanup:
    """Build a `PeriodicCheckpointerWithCleanup` around a new `Checkpointer`.

    The inner checkpointer is told to save to disk only on the main process.
    """
    write_enabled = dist.is_main_process()
    inner = Checkpointer(model, save_dir, **checkpointables, save_to_disk=write_enabled)
    return PeriodicCheckpointerWithCleanup(
        inner,
        period,
        max_iter=max_iter,
        max_to_keep=max_to_keep,
    )
|
torch_hub/facebookresearch_dinov2_main/dinov2/utils/cluster.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ClusterType(Enum):
    """Identifiers of the clusters this module has settings for."""

    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _guess_cluster_type() -> ClusterType:
    """Guess the cluster from ``os.uname()``, falling back to ``ClusterType.FAIR``."""
    system_info = os.uname()
    if system_info.sysname != "Linux":
        return ClusterType.FAIR

    # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
    if system_info.release.endswith("-aws"):
        return ClusterType.AWS

    # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
    if system_info.nodename.startswith("rsc"):
        return ClusterType.RSC

    return ClusterType.FAIR
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    """Return ``cluster_type`` unchanged, or a guessed cluster type when it is ``None``."""
    return _guess_cluster_type() if cluster_type is None else cluster_type
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the checkpoint directory under ``/`` for the given (or guessed) cluster.

    Returns ``None`` when no cluster type can be resolved.
    """
    resolved_type = get_cluster_type(cluster_type)
    if resolved_type is None:
        return None

    dirname_by_cluster = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    dirname = dirname_by_cluster[resolved_type]
    return Path("/") / dirname
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the current user's directory under the cluster checkpoint path.

    Args:
        cluster_type: cluster to look up; ``None`` lets the cluster be guessed.

    Returns:
        ``<checkpoint path>/<USER>``, or ``None`` when no checkpoint path is available.

    Raises:
        RuntimeError: if the ``USER`` environment variable is not set.
    """
    checkpoint_path = get_checkpoint_path(cluster_type)
    if checkpoint_path is None:
        return None

    username = os.environ.get("USER")
    # Explicit check rather than `assert`, which is stripped under `python -O`
    # and would leave `Path / None` to fail with an unrelated TypeError.
    if username is None:
        raise RuntimeError("USER environment variable is not set; cannot determine the user checkpoint path")
    return checkpoint_path / username
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the SLURM partition name for the given (or guessed) cluster.

    Returns ``None`` when no cluster type can be resolved.
    """
    resolved_type = get_cluster_type(cluster_type)
    if resolved_type is None:
        return None

    partition_by_cluster = {
        ClusterType.AWS: "learnaccel",
        ClusterType.FAIR: "learnaccel",
        ClusterType.RSC: "learn",
    }
    return partition_by_cluster[resolved_type]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Assemble executor parameters: defaults, cluster-specific adjustments, then overrides.

    Args:
        nodes: number of nodes to request.
        num_gpus_per_node: GPUs per node; also used as the number of tasks per node.
        cluster_type: cluster to tune for; ``None`` lets the cluster be guessed.
        **kwargs: additional parameters, applied last so they override everything else.

    Returns:
        The parameter dictionary.
    """
    partition = get_slurm_partition(cluster_type)
    resolved_type = get_cluster_type(cluster_type)

    # AWS and RSC use 12 CPUs per task; every other cluster uses 10.
    cpus_per_task = 12 if resolved_type in (ClusterType.AWS, ClusterType.RSC) else 10

    params = {}
    if resolved_type != ClusterType.AWS:
        # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        # (the entry is left out entirely on AWS)
        params["mem_gb"] = 0
    params["gpus_per_node"] = num_gpus_per_node
    params["tasks_per_node"] = num_gpus_per_node  # one task per GPU
    params["cpus_per_task"] = cpus_per_task
    params["nodes"] = nodes
    params["slurm_partition"] = partition

    # set additional parameters / apply overrides
    params.update(kwargs)
    return params
|
torch_hub/facebookresearch_dinov2_main/dinov2/utils/config.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from omegaconf import OmegaConf
|
| 11 |
+
|
| 12 |
+
import dinov2.distributed as distributed
|
| 13 |
+
from dinov2.logging import setup_logging
|
| 14 |
+
from dinov2.utils import utils
|
| 15 |
+
from dinov2.configs import dinov2_default_config
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger("dinov2")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def apply_scaling_rules_to_cfg(cfg):  # to fix
    """Set ``cfg.optim.lr`` from ``cfg.optim.base_lr`` according to ``cfg.optim.scaling_rule``.

    Only the ``"sqrt_wrt_1024"`` rule is handled: the base learning rate is multiplied
    by ``sqrt(batch_size_per_gpu * global_size / 1024)``.

    Raises:
        NotImplementedError: for any other scaling rule.
    """
    if cfg.optim.scaling_rule != "sqrt_wrt_1024":
        raise NotImplementedError

    base_lr = cfg.optim.base_lr
    cfg.optim.lr = base_lr
    global_batch_size = cfg.train.batch_size_per_gpu * distributed.get_global_size()
    cfg.optim.lr = cfg.optim.lr * math.sqrt(global_batch_size / 1024.0)
    logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    return cfg
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def write_config(cfg, output_dir, name="config.yaml"):
    """Log the config as YAML, save it to ``output_dir/name`` and return that path."""
    logger.info(OmegaConf.to_yaml(cfg))
    destination = os.path.join(output_dir, name)
    with open(destination, "w") as stream:
        OmegaConf.save(config=cfg, f=stream)
    return destination
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_cfg_from_args(args):
    """Merge the default config, the config file and the command-line overrides.

    Side effects on ``args``: ``output_dir`` is made absolute and a matching
    ``train.output_dir=...`` override is appended to ``opts``.
    """
    args.output_dir = os.path.abspath(args.output_dir)
    args.opts += [f"train.output_dir={args.output_dir}"]

    # Later sources take precedence: defaults < config file < command-line options.
    defaults = OmegaConf.create(dinov2_default_config)
    from_file = OmegaConf.load(args.config_file)
    from_cli = OmegaConf.from_cli(args.opts)
    return OmegaConf.merge(defaults, from_file, from_cli)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def default_setup(args):
    """Enable distributed mode, set up logging, seed the RNGs and log the git state and arguments."""
    global logger

    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    # Offset the seed by the rank so each process gets a different one.
    utils.fix_random_seeds(seed + rank)
    logger.info("git:\n  {}\n".format(utils.get_sha()))
    arg_lines = ["%s: %s" % (key, str(value)) for key, value in sorted(dict(vars(args)).items())]
    logger.info("\n".join(arg_lines))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def setup(args):
    """
    Build the config from ``args`` and run the basic setup steps.

    Creates the output directory, runs ``default_setup``, applies the learning-rate
    scaling rule and writes the config to the output directory.
    """
    config = get_cfg_from_args(args)
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(config)
    write_config(config, args.output_dir)
    return config
|
torch_hub/facebookresearch_dinov2_main/dinov2/utils/param_groups.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger("dinov2")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
    """
    Compute the lr decay multiplier for one ViT parameter.

    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
        force_is_backbone (bool): treat the parameter as a backbone parameter even
            when its name does not start with "backbone".
        chunked_blocks (bool): when the name holds "blocks." without a leading dot,
            read the block index from the second field after "blocks" instead of the first.
    Returns:
        lr decay rate for the given parameter: ``lr_decay_rate ** (num_layers + 1 - layer_id)``.
    """
    embedding_names = ("pos_embed", "patch_embed", "mask_token", "cls_token", "register_tokens")

    # Parameters that match no rule below keep the last layer id, i.e. a multiplier of 1.
    layer_id = num_layers + 1

    if name.startswith("backbone") or force_is_backbone:
        dotted_match = any("." + token in name for token in embedding_names)
        bare_match = force_is_backbone and any(token in name for token in embedding_names)
        if dotted_match or bare_match:
            # Embeddings and special tokens sit before the first block.
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            fields = name[name.find(".blocks.") :].split(".")
            layer_id = int(fields[2]) + 1
        elif "blocks." in name and "residual." not in name:
            fields = name[name.find("blocks.") :].split(".")
            index_position = 2 if chunked_blocks else 1
            layer_id = int(fields[index_position]) + 1

    return lr_decay_rate ** (num_layers + 1 - layer_id)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
    """Build one parameter-group dict per trainable parameter of ``model``.

    Each dict holds ``params``, ``is_last_layer``, ``lr_multiplier``, ``wd_multiplier``
    and ``name``. The lr multiplier comes from the layer-wise decay, further scaled by
    ``patch_embed_lr_mult`` for names containing "patch_embed". Weight decay is disabled
    for biases and for names containing "norm" or "gamma".
    """
    chunked_blocks = False
    # The number of blocks is read from whichever attribute the model exposes.
    if hasattr(model, "n_blocks"):
        logger.info("chunked fsdp")
        n_blocks = model.n_blocks
        chunked_blocks = model.chunked_blocks
    elif hasattr(model, "blocks"):
        logger.info("first code branch")
        n_blocks = len(model.blocks)
    elif hasattr(model, "backbone"):
        logger.info("second code branch")
        n_blocks = len(model.backbone.blocks)
    else:
        logger.info("else code branch")
        n_blocks = 0

    all_param_groups = []
    for raw_name, param in model.named_parameters():
        name = raw_name.replace("_fsdp_wrapped_module.", "")
        if not param.requires_grad:
            continue

        lr_multiplier = get_vit_lr_decay_rate(
            name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
        )
        if "patch_embed" in name:
            lr_multiplier = lr_multiplier * patch_embed_lr_mult

        skip_weight_decay = name.endswith(".bias") or "norm" in name or "gamma" in name
        group = {
            "params": param,
            "is_last_layer": "last_layer" in name,
            "lr_multiplier": lr_multiplier,
            "wd_multiplier": 0.0 if skip_weight_decay else 1.0,
            "name": name,
        }

        all_param_groups.append(group)
        logger.info(f"""{name}: lr_multiplier: {group["lr_multiplier"]}, wd_multiplier: {group["wd_multiplier"]}""")

    return all_param_groups
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
    """Merge parameter groups whose values for ``keys`` have the same string form.

    Returns:
        A view over the fused groups, in first-seen order. Each fused group holds
        ``params`` (the list of the merged entries' ``params``) plus one entry per key.
    """
    fused = {}
    for entry in all_params_groups:
        # Groups are identified by the concatenated "<key><value>_" fragments.
        identifier = "".join(f"{key}{entry[key]!s}_" for key in keys)
        group = fused.setdefault(identifier, {"params": []})
        group.update((key, entry[key]) for key in keys)
        group["params"].append(entry["params"])
    return fused.values()
|
torch_hub/facebookresearch_dinov2_main/docs/README_CHANNEL_ADAPTIVE_DINO.md
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Scaling Channel-Adaptive Self-Supervised Learning
|
| 3 |
+
|
| 4 |
+
[[`Paper`](https://openreview.net/forum?id=pT8sgtRVAf)] [[`BibTeX`](#citing-channeladaptivedino-and-dinov2)]
|
| 5 |
+
|
| 6 |
+
**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**
|
| 7 |
+
|
| 8 |
+
Alice V. De Lorenci, Seungeun Yi, Théo Moutakanni, Piotr Bojanowski, Camille Couprie, Juan C. Caicedo, Wolfgang M. Pernice,
|
| 9 |
+
|
| 10 |
+
with special thanks to Elouan Gardes for his contributions to the codebase.
|
| 11 |
+
|
| 12 |
+
PyTorch implementation and pretrained model for ChannelAdaptive-DINO.
|
| 13 |
+
|
| 14 |
+
The contents of this repo, including the code and model weights, are intended for research use only. It is not for use in medical procedures, including any diagnostics, treatment, or curative applications. Do not use this model for any clinical purpose or as a substitute for professional medical judgement.
|
| 15 |
+
|
| 16 |
+

|
| 17 |
+
|
| 18 |
+
## Pretrained model
|
| 19 |
+
|
| 20 |
+
You can download the model weights trained on the Extended CHAMMI dataset (combination of five cell microscopy datasets with variable numbers of channels) on torchhub.
|
| 21 |
+
|
| 22 |
+
## Installation
|
| 23 |
+
|
| 24 |
+
Follow the instructions in the DINOv2 README. There are two additional dependencies: pandas and tifffile.
|
| 25 |
+
|
| 26 |
+
## What is included / not included
|
| 27 |
+
|
| 28 |
+
This repository includes the Bag of Channels implementation, not the hierarchical attention approach.
|
| 29 |
+
|
| 30 |
+
## Data preparation
|
| 31 |
+
|
| 32 |
+
The CHAMMI dataset is available [here](https://github.com/chaudatascience/channel_adaptive_models).
|
| 33 |
+
|
| 34 |
+
The HPA-FoV dataset is available [here](https://www.ebi.ac.uk/biostudies/bioimages/studies/S-BIAD2443)
|
| 35 |
+
|
| 36 |
+
Content: a directory new_512_whole_images and two csv files:
|
| 37 |
+
|
| 38 |
+
"whole_images_512_test.csv"
|
| 39 |
+
|
| 40 |
+
"whole_images_512_train.csv"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
:warning: To execute the commands provided in the next sections for training and evaluation, the `dinov2` package should be included in the Python module search path, i.e. simply prefix the command to run with `PYTHONPATH=.`.
|
| 44 |
+
|
| 45 |
+
## Training
|
| 46 |
+
|
| 47 |
+
### Fast setup: training Channel-Adaptive DINO ViT-L/16 on the HPA-FoV dataset
|
| 48 |
+
|
| 49 |
+
Run Channel-Adaptive DINO training on 4 A100-80GB nodes (32 GPUs) in a SLURM cluster environment with submitit:
|
| 50 |
+
|
| 51 |
+
```shell
|
| 52 |
+
python dinov2/run/train/train.py \
|
| 53 |
+
--nodes 4 \
|
| 54 |
+
--config-file dinov2/configs/train/cell_dino/vitl16_boc_hpafov.yaml \
|
| 55 |
+
--output-dir <PATH/TO/OUTPUT/DIR> \
|
| 56 |
+
train.dataset_path=HPAFoV:split=TRAIN:root=<PATH/TO/DATASET>:wildcard=SEPARATE_CHANNELS
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
Training time is approximately 2 days.
|
| 60 |
+
The training code saves the weights of the teacher in the `eval` folder every 12500 iterations for evaluation.
|
| 61 |
+
This example only performs pretraining on the HPA-FoV dataset.
|
| 62 |
+
|
| 63 |
+
## Evaluation
|
| 64 |
+
|
| 65 |
+
The training code regularly saves the teacher weights. In order to evaluate the model, run the following evaluation on a single node:
|
| 66 |
+
|
| 67 |
+
### Linear Evaluation on HPAFoV
|
| 68 |
+
|
| 69 |
+
```shell
|
| 70 |
+
PYTHONPATH=.:dinov2/data python dinov2/run/eval/cell_dino/linear.py \
|
| 71 |
+
--config-file dinov2/configs/eval/cell_dino/vitl16_channel_adaptive_pretrain.yaml \
|
| 72 |
+
--pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_359999/teacher_checkpoint.pth \
|
| 73 |
+
--output-dir <PATH/TO/OUTPUT/DIR>/eval/training_359999/linear \
|
| 74 |
+
--train-dataset HPAFoV:split=TRAIN:mode=PROTEIN_LOCALIZATION:root=<PATH/TO/DATASET> \
|
| 75 |
+
--val-dataset HPAFoV:split=VAL:mode=PROTEIN_LOCALIZATION:root=<PATH/TO/DATASET> \
|
| 76 |
+
--val-metric-type mean_per_class_multilabel_f1 \
|
| 77 |
+
--loss-type binary_cross_entropy \
|
| 78 |
+
--bag-of-channels \
|
| 79 |
+
--crop-size 384 \
|
| 80 |
+
--n-last-blocks 4 \
|
| 81 |
+
--batch-size 32 \
|
| 82 |
+
--epoch-length 145 \
|
| 83 |
+
--epochs 30 \
|
| 84 |
+
--avgpool
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
### KNN classification on CHAMMI
|
| 88 |
+
|
| 89 |
+
Go to the `scripts/cell_dino` directory, modify the paths at the top of `launcher_knn_eval_on_chammi.sh`, and run
|
| 90 |
+
|
| 91 |
+
```shell
|
| 92 |
+
./launcher_knn_eval_on_chammi.sh WTC TASK_ONE ;
|
| 93 |
+
./launcher_knn_eval_on_chammi.sh WTC TASK_TWO ;
|
| 94 |
+
./launcher_knn_eval_on_chammi.sh HPA TASK_ONE ;
|
| 95 |
+
./launcher_knn_eval_on_chammi.sh HPA TASK_TWO ;
|
| 96 |
+
./launcher_knn_eval_on_chammi.sh HPA TASK_THREE ;
|
| 97 |
+
./launcher_knn_eval_on_chammi.sh CP TASK_ONE ;
|
| 98 |
+
./launcher_knn_eval_on_chammi.sh CP TASK_TWO ;
|
| 99 |
+
./launcher_knn_eval_on_chammi.sh CP TASK_THREE ;
|
| 100 |
+
./launcher_knn_eval_on_chammi.sh CP TASK_FOUR ;
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
### Linear classification on CHAMMI
|
| 104 |
+
|
| 105 |
+
Go to the docs directory, modify the paths in `launcher_CHAMMI_eval.sh`, and run
|
| 106 |
+
|
| 107 |
+
```shell
|
| 108 |
+
./launcher_CHAMMI_eval.sh WTC TASK_ONE ;
|
| 109 |
+
./launcher_CHAMMI_eval.sh WTC TASK_TWO ;
|
| 110 |
+
./launcher_CHAMMI_eval.sh HPA TASK_ONE ;
|
| 111 |
+
./launcher_CHAMMI_eval.sh HPA TASK_TWO ;
|
| 112 |
+
./launcher_CHAMMI_eval.sh HPA TASK_THREE ;
|
| 113 |
+
./launcher_CHAMMI_eval.sh CP TASK_ONE ;
|
| 114 |
+
./launcher_CHAMMI_eval.sh CP TASK_TWO ;
|
| 115 |
+
./launcher_CHAMMI_eval.sh CP TASK_THREE ;
|
| 116 |
+
./launcher_CHAMMI_eval.sh CP TASK_FOUR ;
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
| | WTC - Task 1 | WTC - Task 2 | HPA - Task 1 | HPA - Task 2 | HPA - Task 3 | CP - Task 1 | CP - Task 2 | CP - Task 3 | CP - Task 4 |
|
| 120 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 121 |
+
| knn reproduced | 80.3 | 79.3 | 91.6 | 61.4 | 29.0 | 89.8 | 57.6 | 23.4 | 18.4 |
|
| 122 |
+
| knn paper | 79.4 | 79.0 | 86.6 | 59.3 | 29.6 | 92.6 | 57.6 | 22.1 | 18.5 |
|
| 123 |
+
| Linear reproduced | 89.9 | 87.9 | 92.7 | 87.2 | 66.2 | 89.9 | 59.8 | 26.6 | 32.5|
|
| 124 |
+
| Linear paper | 90.5 | 89.2 | 88.3 | 84.7 | 65.0 | 90.5 | 60.5 | 25.8 | 32.7|
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
## License
|
| 128 |
+
|
| 129 |
+
Cell-DINO code is released under the CC BY-NC licence. See [LICENSE_CELL_DINO_CODE](LICENSE_CELL_DINO_CODE) for additional details.
|
| 130 |
+
Model weights will be released under the FAIR Non-Commercial Research License. See [LICENSE_CELL_DINO_MODELS](LICENSE_CELL_DINO_MODELS) for additional details.
|
| 131 |
+
|
| 132 |
+
## Contributing
|
| 133 |
+
|
| 134 |
+
See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
|
| 135 |
+
|
| 136 |
+
## Citing ChannelAdaptiveDINO and DINOv2
|
| 137 |
+
|
| 138 |
+
If you find this repository useful, please consider giving a star :star: and citation :t-rex::
|
| 139 |
+
|
| 140 |
+
```
|
| 141 |
+
@misc{Delorenci2025scaling,
|
| 142 |
+
title={Scaling Channel-Adaptive Self-Supervised Learning},
|
| 143 |
+
author={V. De Lorenci, Alice and Yi, Seungeun and Moutakanni, Theo and Bojanowski, Piotr and Couprie, Camille and Caicedo, Juan C. and Pernice, Wolfgang M.},
|
| 144 |
+
journal={TMLR},
|
| 145 |
+
year={2025}
|
| 146 |
+
}
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
```
|
| 150 |
+
@misc{oquab2023dinov2,
|
| 151 |
+
title={DINOv2: Learning Robust Visual Features without Supervision},
|
| 152 |
+
author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
|
| 153 |
+
journal={arXiv:2304.07193},
|
| 154 |
+
year={2023}
|
| 155 |
+
}
|
| 156 |
+
```
|
torch_hub/facebookresearch_dinov2_main/notebooks/cell_dino/inference.ipynb
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "6c3f1fe9-af40-4a57-aff0-99313d722f34",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"# Copyright (c) Meta Platforms, Inc. and affiliates.\n",
|
| 11 |
+
"# This source code is licensed under the CC-by-NC licence,\n",
|
| 12 |
+
"# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree."
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"cell_type": "code",
|
| 17 |
+
"execution_count": null,
|
| 18 |
+
"id": "bfe8d7c4-995c-44b4-a8d1-97be2badce8c",
|
| 19 |
+
"metadata": {},
|
| 20 |
+
"outputs": [],
|
| 21 |
+
"source": [
|
| 22 |
+
"SAMPLE_IMAGES_DIR = \"sample_images/\" # path to directory with cell images.\n",
|
| 23 |
+
"REPO_DIR=\"\" # path to the dinov2 repo.\n",
|
| 24 |
+
"# Also define the urls of the pretrained models CPurl, SCurl, FOVurl used in the next cell. Instructions to get the models urls are in the README.md."
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "code",
|
| 29 |
+
"execution_count": null,
|
| 30 |
+
"id": "70b2a71b-8fb2-4ec3-8cd1-21d6f770f704",
|
| 31 |
+
"metadata": {},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"import torch\n",
|
| 35 |
+
"cell_dino_vits8 = torch.hub.load(REPO_DIR, 'cell_dino_cp_vits8', source='local', pretrained_url=CPurl)\n",
|
| 36 |
+
"cell_dino_vitl16_sc = torch.hub.load(REPO_DIR, 'cell_dino_hpa_vitl16', source='local', pretrained_url=SCurl)\n",
|
| 37 |
+
"cell_dino_vitl16_fov = torch.hub.load(REPO_DIR, 'cell_dino_hpa_vitl16', source='local', pretrained_url=FOVurl)\n",
|
| 38 |
+
"#channel_adaptive_dino_vitl16 = torch.hub.load(REPO_DIR, 'channel_adaptive_dino_vitl16', source='local', pretrained_url=CAurl)\n",
|
| 39 |
+
"# cell_dino_vitl14 = torch.hub.load(REPO_DIR, 'cell_dino_hpa_vitl14', source='local', pretrained_url=HRurl)"
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": null,
|
| 45 |
+
"id": "a6899c21-c4c2-43d4-8cdc-f6cce2c3686b",
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"outputs": [],
|
| 48 |
+
"source": [
|
| 49 |
+
"import torch\n",
|
| 50 |
+
"import torchvision\n",
|
| 51 |
+
"from dinov2.hub.cell_dino.backbones import cell_dino_hpa_vitl16, cell_dino_cp_vits8\n",
|
| 52 |
+
"from functools import partial\n",
|
| 53 |
+
"from dinov2.eval.utils import ModelWithIntermediateLayers\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"DEVICE = \"cuda:0\"\n",
|
| 56 |
+
"\n",
|
| 57 |
+
"class self_normalize(object):\n",
|
| 58 |
+
" def __call__(self, x):\n",
|
| 59 |
+
" x = x / 255\n",
|
| 60 |
+
" m = x.mean((-2, -1), keepdim=True)\n",
|
| 61 |
+
" s = x.std((-2, -1), unbiased=False, keepdim=True)\n",
|
| 62 |
+
" x -= m\n",
|
| 63 |
+
" x /= s + 1e-7\n",
|
| 64 |
+
" return x\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"normalize = self_normalize()"
|
| 67 |
+
]
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"cell_type": "code",
|
| 71 |
+
"execution_count": null,
|
| 72 |
+
"id": "779ecbb7-3247-4b9e-9248-25c8cbd74d24",
|
| 73 |
+
"metadata": {},
|
| 74 |
+
"outputs": [],
|
| 75 |
+
"source": [
|
| 76 |
+
"# ---------------------- Example inference on HPA-FoV dataset --------------------------\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"# 1- Read one human protein atlas HPA-FoV image (4 channels)\n",
|
| 79 |
+
"img = torchvision.io.read_image(SAMPLE_IMAGES_DIR + \"HPA_FoV_00070df0-bbc3-11e8-b2bc-ac1f6b6435d0.png\")\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"# 2- Normalise image as it was done for training\n",
|
| 82 |
+
"img_hpa_fov = img.unsqueeze(0).to(device=DEVICE)\n",
|
| 83 |
+
"img_hpa_fov = normalize(img_hpa_fov)\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"# 3- Load model\n",
|
| 86 |
+
"cell_dino_model = cell_dino_vitl16_fov\n",
|
| 87 |
+
"cell_dino_model.to(device=DEVICE)\n",
|
| 88 |
+
"cell_dino_model.eval()\n",
|
| 89 |
+
"\n",
|
| 90 |
+
"# 4- Inference\n",
|
| 91 |
+
"features = cell_dino_model(img_hpa_fov)\n",
|
| 92 |
+
"print(features)\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"# 5- [Optional] feature extractor as used for linear evaluation\n",
|
| 95 |
+
"autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float)\n",
|
| 96 |
+
"model_with_interm_layers = ModelWithIntermediateLayers(cell_dino_model, 4, autocast_ctx)\n",
|
| 97 |
+
"features_with_interm_layers = model_with_interm_layers(img_hpa_fov)"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": null,
|
| 103 |
+
"id": "85e67e67-5824-4135-8ca2-f8396cf95cb2",
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"outputs": [],
|
| 106 |
+
"source": [
|
| 107 |
+
"# ---------------------- Example inference on cell painting data --------------------------\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"# 1- Read one cell painting image (5 channels)\n",
|
| 110 |
+
"img = torchvision.io.read_image(SAMPLE_IMAGES_DIR + \"CP_BBBC036_24277_a06_1_976@140x149.png\")\n",
|
| 111 |
+
"img5_channels = torch.zeros([1, 5, 160, 160])\n",
|
| 112 |
+
"for c in range(5):\n",
|
| 113 |
+
" img5_channels[0, c] = img[0, :, 160 * c : 160 * (c + 1)]\n",
|
| 114 |
+
"img5_channels = img5_channels.to(device=DEVICE)\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"# 2- Normalise image as it was done for training\n",
|
| 117 |
+
"img5_channels = normalize(img5_channels)\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"# 3- Load model\n",
|
| 120 |
+
"cell_dino_model = cell_dino_vits8\n",
|
| 121 |
+
"cell_dino_model.to(device=DEVICE)\n",
|
| 122 |
+
"cell_dino_model.eval()\n",
|
| 123 |
+
"\n",
|
| 124 |
+
"# 4- Inference\n",
|
| 125 |
+
"features = cell_dino_model(img5_channels)\n",
|
| 126 |
+
"print(features[0,0:10])"
|
| 127 |
+
]
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"cell_type": "code",
|
| 131 |
+
"execution_count": null,
|
| 132 |
+
"id": "ada26fcc-fd27-4dbf-983c-bbe06fe04b6f",
|
| 133 |
+
"metadata": {},
|
| 134 |
+
"outputs": [],
|
| 135 |
+
"source": [
|
| 136 |
+
"# ---------------------- Example inference on HPA single cell dataset --------------------------\n",
|
| 137 |
+
"\n",
|
| 138 |
+
"# Read one human protein atlas HPA single cell image (4 channels)\n",
|
| 139 |
+
"img = torchvision.io.read_image(SAMPLE_IMAGES_DIR + \"HPA_single_cell_00285ce4-bba0-11e8-b2b9-ac1f6b6435d0_15.png\")\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"# 2- Normalise image as it was done for training\n",
|
| 142 |
+
"img_hpa = img.unsqueeze(0).to(device=DEVICE)\n",
|
| 143 |
+
"img_hpa = normalize(img_hpa)\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"# 3- Load model\n",
|
| 146 |
+
"cell_dino_model = cell_dino_vitl16_sc\n",
|
| 147 |
+
"cell_dino_model.to(device=DEVICE)\n",
|
| 148 |
+
"cell_dino_model.eval()\n",
|
| 149 |
+
"\n",
|
| 150 |
+
"# 4- Inference\n",
|
| 151 |
+
"features = cell_dino_model(img_hpa)\n",
|
| 152 |
+
"print(features)\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"torch.save(features.cpu(), \"sample_features_hpa.pt\")"
|
| 155 |
+
]
|
| 156 |
+
}
|
| 157 |
+
],
|
| 158 |
+
"metadata": {
|
| 159 |
+
"kernelspec": {
|
| 160 |
+
"display_name": "Python (mypy310env)",
|
| 161 |
+
"language": "python",
|
| 162 |
+
"name": "mypy310env"
|
| 163 |
+
},
|
| 164 |
+
"language_info": {
|
| 165 |
+
"codemirror_mode": {
|
| 166 |
+
"name": "ipython",
|
| 167 |
+
"version": 3
|
| 168 |
+
},
|
| 169 |
+
"file_extension": ".py",
|
| 170 |
+
"mimetype": "text/x-python",
|
| 171 |
+
"name": "python",
|
| 172 |
+
"nbconvert_exporter": "python",
|
| 173 |
+
"pygments_lexer": "ipython3",
|
| 174 |
+
"version": "3.10.19"
|
| 175 |
+
}
|
| 176 |
+
},
|
| 177 |
+
"nbformat": 4,
|
| 178 |
+
"nbformat_minor": 5
|
| 179 |
+
}
|
torch_hub/facebookresearch_dinov2_main/notebooks/depth_estimation.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
torch_hub/facebookresearch_dinov2_main/notebooks/semantic_segmentation.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
torch_hub/facebookresearch_dinov2_main/scripts/cell_dino/launcher_knn_eval_on_chammi.sh
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

# Runs dinov2/run/eval/cell_dino/knn.py for one CHAMMI dataset and task.
#
# 1 : modify CHANNEL_AGNOSTIC_CELL_MODEL, CHAMMI_DATA_PATH and OUTPUT_DIR below
# 2 : call this script with the two arguments specified below
#
# Arguments:
#   $1 : dataset, e.g. CP
#   $2 : task number, e.g. TASK_TWO

# Both positional arguments are used to build dataset strings below, so
# fail early instead of launching the evaluation with empty names.
if [ "$#" -lt 2 ]; then
    echo "Usage: $0 <DATASET> <TASK>" >&2
    exit 1
fi

CHAMMI_DATA_PATH=""
CHANNEL_AGNOSTIC_CELL_MODEL="path_to_model/model.pth"
OUTPUT_DIR="YOUR_OUTPUT_PATH_${1}_${2}"

# Optional extra arguments are kept in an array so that paths containing
# spaces survive as single words when passed to python.
# NOTE(review): the two metadata paths below use different layouts
# ("CP/..." vs "CHAMMI/HPA/...") - confirm which one matches the dataset root.
if [ "$2" == "TASK_FOUR" ]; then
    OTHER_ARG=(--leave-one-out-dataset "$CHAMMI_DATA_PATH/CP/enriched_meta.csv")
elif [ "$1" == "HPA" ] && [ "$2" == "TASK_THREE" ]; then
    OTHER_ARG=(--leave-one-out-dataset "$CHAMMI_DATA_PATH/CHAMMI/HPA/enriched_meta.csv")
else
    OTHER_ARG=()
fi
echo "${OTHER_ARG[*]}"

# NOTE(review): the relative paths assume the script is run from its own
# directory; PYTHONPATH starts with ".." while the script path uses
# "../.." - verify that the dinov2 package is importable this way.
PYTHONPATH=..:../../dinov2/data python ../../dinov2/run/eval/cell_dino/knn.py \
    --config-file ../../dinov2/configs/eval/cell_dino/vitl16_channel_adaptive_pretrain.yaml \
    --pretrained-weights "$CHANNEL_AGNOSTIC_CELL_MODEL" \
    --output-dir "$OUTPUT_DIR" \
    --train-dataset "CHAMMI_${1}:split=TRAIN:root=$CHAMMI_DATA_PATH" \
    --val-dataset "CHAMMI_${1}:split=${2}:root=$CHAMMI_DATA_PATH" \
    --metric-type mean_per_class_multiclass_f1 \
    --crop-size 224 \
    --batch-size 32 \
    --resize-size 256 \
    --bag-of-channels \
    "${OTHER_ARG[@]}"
|
torch_hub/trusted_list
ADDED
|
File without changes
|