File size: 4,261 Bytes
c20d7cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
"""Contains Dense Transformer Prediction architecture.
Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
import torch
import torch.nn as nn
from sharp.models.presets import (
MONODEPTH_ENCODER_DIMS_MAP,
MONODEPTH_HOOK_IDS_MAP,
ViTPreset,
)
from .base_encoder import BaseEncoder
from .spn_encoder import SlidingPyramidNetwork
from .vit_encoder import create_vit
def create_monodepth_encoder(
    patch_encoder_preset: ViTPreset,
    image_encoder_preset: ViTPreset,
    use_patch_overlap: bool = True,
    last_encoder: int = 256,
) -> SlidingPyramidNetwork:
    """Creates DepthDensePredictionTransformer model.
    Args:
        patch_encoder_preset: The preset patch encoder architecture in SPN.
        image_encoder_preset: The preset image encoder architecture in SPN.
        use_patch_overlap: Whether to use overlap between patches in SPN.
        last_encoder: last number of encoder features.
    """
    # The patch encoder must expose intermediate features so the pyramid can
    # assemble them; the image encoder only contributes its final output.
    patch_vit = create_vit(
        preset=patch_encoder_preset,
        intermediate_features_ids=MONODEPTH_HOOK_IDS_MAP[patch_encoder_preset],
        # We always need to output intermediate features for assembly.
    )
    image_vit = create_vit(
        preset=image_encoder_preset,
        intermediate_features_ids=None,
    )
    # Encoder dims: the requested last-stage width followed by the preset's
    # per-stage widths.
    return SlidingPyramidNetwork(
        dims_encoder=[last_encoder] + MONODEPTH_ENCODER_DIMS_MAP[patch_encoder_preset],
        patch_encoder=patch_vit,
        image_encoder=image_vit,
        use_patch_overlap=use_patch_overlap,
    )
class ProjectionModule(nn.Module):
    """Apply projection of features."""

    def __init__(self, dims_in: list[int], dims_out: list[int]) -> None:
        """Initialize projection module."""
        super().__init__()
        if len(dims_in) != len(dims_out):
            raise ValueError("Length of dims_in must be same as length of dims_out.")
        # One 1x1 convolution per feature level, mapping dims_in[i] -> dims_out[i].
        projections = []
        for channels_in, channels_out in zip(dims_in, dims_out):
            projections.append(nn.Conv2d(channels_in, channels_out, 1))
        self.convs = nn.ModuleList(projections)

    def forward(self, encodings: list[torch.Tensor]) -> list[torch.Tensor]:
        """Apply projection module."""
        if len(encodings) != len(self.convs):
            raise ValueError("Number of encodings must be equal to number of projections.")
        projected = []
        for conv, encoding in zip(self.convs, encodings):
            projected.append(conv(encoding))
        return projected
class MonodepthFeatureEncoder(BaseEncoder):
    """A wrapper around monodepth network to extract features."""

    def __init__(
        self,
        monodepth_encoder: SlidingPyramidNetwork,
        output_dims: list[int] | None = None,
        freeze_projection: bool = False,
    ) -> None:
        """Initialize MonodepthFeatureExtractor.

        Args:
            monodepth_encoder: Backbone SPN producing multi-resolution features.
            output_dims: Optional per-level output channel counts. When given,
                features are projected with 1x1 convolutions; when None,
                features pass through unchanged.
            freeze_projection: If True, projection weights are excluded from
                gradient updates.

        Raises:
            ValueError: If output_dims is set but its length does not match
                the number of encoder feature levels.
        """
        super().__init__()
        self.encoder = monodepth_encoder
        # The monodepth network returns two feature maps for the first entry in
        # backbone.encoder.dims_encoder.
        monodepth_dims = self.encoder.dims_encoder
        if output_dims is not None:
            if len(output_dims) != len(monodepth_dims):
                raise ValueError(
                    "When set, number of output dimensions must be equal to output "
                    f"dimensions of monodepth model {len(monodepth_dims)}."
                )
            self.projection = ProjectionModule(monodepth_dims, output_dims)
            self.output_dims = output_dims
        else:
            # No projection requested: forward encoder features as-is.
            self.projection = nn.Identity()
            self.output_dims = monodepth_dims
        if freeze_projection:
            self.projection.requires_grad_(False)

    def forward(self, input_features: torch.Tensor) -> list[torch.Tensor]:
        """Extract multi-resolution features.

        Only the first three channels of input_features are encoded; any
        extra channels are ignored.
        """
        encodings = self.encoder(input_features[:, :3].contiguous())
        return self.projection(encodings)

    def internal_resolution(self) -> int:
        """Internal resolution of the encoder."""
        return self.encoder.internal_resolution()
|