"""Contains Sliding Pyramid Network architecture.

For licensing see accompanying LICENSE file.

Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

import math
from typing import Iterable

import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F

from sharp.utils.training import checkpoint_wrapper

from .base_encoder import BaseEncoder
from .vit_encoder import TimmViT

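# `len` and `int` are Python builtins that torch.fx cannot trace through;
# registering them with torch.fx.wrap keeps them as opaque leaf calls in the
# traced graph instead of attempting to symbolically evaluate them.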
non_traceable_ops = ("len", "int")
for op in non_traceable_ops:
    torch.fx.wrap(op)


class SlidingPyramidNetwork(BaseEncoder):
    """Sliding Pyramid Network.

    An encoder aimed at creating multi-resolution encodings from Vision Transformers.

    Reference: Bochkovskii et al. - "Depth Pro: Sharp monocular metric depth in
    less than a second." (ICLR 2025)
    """

    def __init__(
        self,
        dims_encoder: Iterable[int],
        patch_encoder: TimmViT,
        image_encoder: TimmViT,
        use_patch_overlap: bool = True,
    ):
        """Initialize Sliding Pyramid Network.

        The framework
        1. creates an image pyramid,
        2. generates overlapping patches with a sliding window at each pyramid level,
        3. creates batched encodings via vision transformer backbones,
        4. produces multi-resolution encodings.

        Args:
            dims_encoder: Dimensions of the encoder at different layers.
            patch_encoder: Backbone used for the high-resolution levels of the pyramid.
            image_encoder: Backbone used for the low-resolution level of the pyramid.
            use_patch_overlap: Whether to use overlap between patches in SPN.

        """
        super().__init__()

        self.dim_in = patch_encoder.dim_in

        self.dims_encoder = list(dims_encoder)
        self.patch_encoder = patch_encoder
        self.image_encoder = image_encoder

        base_embed_dim = patch_encoder.embed_dim
        lowres_embed_dim = image_encoder.embed_dim
        self.patch_size = patch_encoder.internal_resolution()

        self.grad_checkpointing = False
        self.use_patch_overlap = use_patch_overlap

        self.patch_intermediate_features_ids = patch_encoder.intermediate_features_ids
        if (
            not isinstance(self.patch_intermediate_features_ids, list)
            or len(self.patch_intermediate_features_ids) != 4
        ):
            raise ValueError("Patch intermediate feature ids must be a 4-item list.")

        self.image_intermediate_features_ids = image_encoder.intermediate_features_ids

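        # Local helper: a 1x1 projection to the target width followed by
        # `upsample_layers` stride-2 transposed convolutions, so each block
        # upsamples its input by a factor of 2**upsample_layers.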
        def _create_project_upsample_block(
            dim_in: int,
            dim_out: int,
            upsample_layers: int,
            dim_intermediate: int | None = None,
        ) -> nn.Module:
            if dim_intermediate is None:
                dim_intermediate = dim_out

            # Project to the intermediate width with a 1x1 convolution.
            blocks = [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=dim_intermediate,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False,
                )
            ]

            # Upsample by 2x per layer with transposed convolutions.
            blocks += [
                nn.ConvTranspose2d(
                    in_channels=dim_intermediate if i == 0 else dim_out,
                    out_channels=dim_out,
                    kernel_size=2,
                    stride=2,
                    padding=0,
                    bias=False,
                )
                for i in range(upsample_layers)
            ]

            return nn.Sequential(*blocks)

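        # Five output scales: latent0/latent1 come from intermediate ViT
        # features of the full-resolution tiles and are upsampled 8x and 4x;
        # scales 0-2 come from the final encodings of the three pyramid
        # levels, each upsampled 2x.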
        self.upsample_latent0 = _create_project_upsample_block(
            dim_in=base_embed_dim,
            dim_out=self.dims_encoder[0],
            upsample_layers=3,
            dim_intermediate=self.dims_encoder[1],
        )
        self.upsample_latent1 = _create_project_upsample_block(
            dim_in=base_embed_dim, dim_out=self.dims_encoder[1], upsample_layers=2
        )

        self.upsample0 = _create_project_upsample_block(
            dim_in=base_embed_dim, dim_out=self.dims_encoder[2], upsample_layers=1
        )
        self.upsample1 = _create_project_upsample_block(
            dim_in=base_embed_dim, dim_out=self.dims_encoder[3], upsample_layers=1
        )
        self.upsample2 = _create_project_upsample_block(
            dim_in=base_embed_dim, dim_out=self.dims_encoder[4], upsample_layers=1
        )

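        # The global image encoder runs on the 1/4-scale image; its features
        # are upsampled 2x and fused with the 1/4-scale patch features via a
        # 1x1 convolution so the coarsest output carries global context.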
        self.upsample_lowres = nn.ConvTranspose2d(
            in_channels=lowres_embed_dim,
            out_channels=self.dims_encoder[4],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
        )
        self.fuse_lowres = nn.Conv2d(
            in_channels=(self.dims_encoder[4] + self.dims_encoder[4]),
            out_channels=self.dims_encoder[4],
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

    def internal_resolution(self) -> int:
        """Return the full input image size of the SPN (4x the patch encoder's)."""
        return self.patch_size * 4

    @torch.jit.ignore
    def set_grad_checkpointing(self, is_enabled=True):
        """Enable or disable gradient checkpointing on both encoders."""
        self.grad_checkpointing = is_enabled
        self.patch_encoder.set_grad_checkpointing(is_enabled)
        self.image_encoder.set_grad_checkpointing(is_enabled)

    @torch.jit.ignore
    def set_requires_grad_(self, patch_encoder: bool, image_encoder: bool):
        """Set requires_grad for the patch-encoder and image-encoder branches."""
        self.patch_encoder.requires_grad_(patch_encoder)
        self.image_encoder.requires_grad_(image_encoder)

        # The classification heads are not used by the SPN; keep them frozen.
        self.patch_encoder.head.requires_grad_(False)
        self.image_encoder.head.requires_grad_(False)

        # The upsampling blocks consume patch-encoder features.
        self.upsample_latent0.requires_grad_(patch_encoder)
        self.upsample_latent1.requires_grad_(patch_encoder)
        self.upsample0.requires_grad_(patch_encoder)
        self.upsample1.requires_grad_(patch_encoder)
        self.upsample2.requires_grad_(patch_encoder)

        self.upsample_lowres.requires_grad_(image_encoder)

        # Fusion touches both branches; train it if either branch trains.
        self.fuse_lowres.requires_grad_(image_encoder or patch_encoder)

    def _create_pyramid(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create a three-level image pyramid at scales 1, 1/2, and 1/4."""
        x0 = x
        x1 = F.interpolate(
            x, size=None, scale_factor=0.5, mode="bilinear", align_corners=False
        )
        x2 = F.interpolate(
            x, size=None, scale_factor=0.25, mode="bilinear", align_corners=False
        )
        return x0, x1, x2

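    # Forward-pass overview (sizes assume the default 384-pixel ViT patches on
    # a 1536x1536 input with overlapping tiling): x0 at full resolution is
    # tiled into 5x5 overlapping patches, x1 at 1/2 resolution into 3x3, and
    # x2 at 1/4 resolution equals a single patch, so it is encoded whole.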
    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Encode input at multiple resolutions."""
        batch_size = x.shape[0]

        x0, x1, x2 = self._create_pyramid(x)

        if self.use_patch_overlap:
            # Overlapping tiles: 25% overlap at full resolution and 50% at
            # half resolution. `padding` is the number of feature pixels
            # trimmed per interior tile border when merging.
            x0_patches = split(x0, overlap_ratio=0.25, patch_size=self.patch_size)
            x1_patches = split(x1, overlap_ratio=0.5, patch_size=self.patch_size)
            x2_patches = x2
            padding = 3
        else:
            # Non-overlapping tiling; nothing to trim when merging.
            x0_patches = split(x0, overlap_ratio=0.0, patch_size=self.patch_size)
            x1_patches = split(x1, overlap_ratio=0.0, patch_size=self.patch_size)
            x2_patches = x2
            padding = 0
        x0_tile_size = x0_patches.shape[0] // batch_size  # tiles per image

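        # Concatenate the tiles of all pyramid levels along the batch
        # dimension so the patch encoder runs a single forward pass.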
        x_pyramid_patches = torch.cat(
            (x0_patches, x1_patches, x2_patches),
            dim=0,
        )

        x_pyramid_encodings, patch_intermediate_features = self.patch_encoder(
            x_pyramid_patches
        )

        # Intermediate ViT features of the full-resolution tiles (the leading
        # slice of the batch) provide the two finest output scales.
        x_latent0_encodings = self.patch_encoder.reshape_feature(
            patch_intermediate_features[self.patch_intermediate_features_ids[0]]
        )
        x_latent0_features = merge(
            x_latent0_encodings[: batch_size * x0_tile_size],
            batch_size=batch_size,
            padding=padding,
        )

        x_latent1_encodings = self.patch_encoder.reshape_feature(
            patch_intermediate_features[self.patch_intermediate_features_ids[1]]
        )
        x_latent1_features = merge(
            x_latent1_encodings[: batch_size * x0_tile_size],
            batch_size=batch_size,
            padding=padding,
        )

        # Split the batched encodings back into the three pyramid levels.
        x0_encodings, x1_encodings, x2_encodings = torch.split(
            x_pyramid_encodings,
            [len(x0_patches), len(x1_patches), len(x2_patches)],
            dim=0,
        )

        x0_features = merge(x0_encodings, batch_size=batch_size, padding=padding)
        # 50% overlap at half resolution doubles the trimmed border width.
        x1_features = merge(x1_encodings, batch_size=batch_size, padding=2 * padding)
        # The 1/4-scale level is a single tile per image; nothing to merge.
        x2_features = x2_encodings

        # Global context from the image encoder on the 1/4-scale image.
        x_lowres_features, image_intermediate_features = self.image_encoder(x2_patches)

        # Project and upsample every scale, optionally under gradient checkpointing.
        x_latent0_features = checkpoint_wrapper(self, self.upsample_latent0, x_latent0_features)
        x_latent1_features = checkpoint_wrapper(self, self.upsample_latent1, x_latent1_features)

        x0_features = checkpoint_wrapper(self, self.upsample0, x0_features)
        x1_features = checkpoint_wrapper(self, self.upsample1, x1_features)
        x2_features = checkpoint_wrapper(self, self.upsample2, x2_features)

        x_lowres_features = checkpoint_wrapper(self, self.upsample_lowres, x_lowres_features)
        x_lowres_features = checkpoint_wrapper(
            self, self.fuse_lowres, torch.cat((x2_features, x_lowres_features), dim=1)
        )

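        # With 384-pixel tiles embedded at 16 pixels per token (24x24 feature
        # maps) and a 1536x1536 input, the five outputs land at 1/2, 1/4,
        # 1/8, 1/16, and 1/32 of the input resolution, respectively.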
        output = [
            x_latent0_features,
            x_latent1_features,
            x0_features,
            x1_features,
            x_lowres_features,
        ]
        return output


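# Example of the tiling layout (hypothetical sizes): split(torch.rand(2, 3,
# 1536, 1536), overlap_ratio=0.25, patch_size=384) produces 5x5 tiles per
# image, stacked tile-major as a (50, 3, 384, 384) tensor, so each tile
# position holds a contiguous batch of 2.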
@torch.fx.wrap
def split(image: torch.Tensor, overlap_ratio: float = 0.25, patch_size: int = 384) -> torch.Tensor:
    """Split the input into small patches with a sliding window.

    Patches are concatenated along the batch dimension in row-major tile
    order, so each tile position contributes a contiguous batch of images.
    """
    patch_stride = int(patch_size * (1 - overlap_ratio))

    image_size = image.shape[-1]  # assumes a square input
    steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1

    x_patch_list = []
    for j in range(steps):
        j0 = j * patch_stride
        j1 = j0 + patch_size

        for i in range(steps):
            i0 = i * patch_stride
            i1 = i0 + patch_size
            x_patch_list.append(image[..., j0:j1, i0:i1])

    return torch.cat(x_patch_list, dim=0)


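# With overlapping tiles, `padding` feature pixels are trimmed from every
# interior border before stitching so adjacent tiles contribute disjoint
# regions; e.g. 25% overlap on 384-pixel tiles with a 16-pixel ViT patch
# embedding overlaps by 6 feature pixels, hence padding=3 per side.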
@torch.fx.wrap
def merge(image_patches: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
    """Merge sliding-window patches back into a single image or feature map.

    Inverse of `split`: trims `padding` pixels from each interior patch
    border so overlapping regions are not counted twice.
    """
    steps = int(math.sqrt(image_patches.shape[0] // batch_size))  # tiles per side

    idx = 0
    output_list = []
    for j in range(steps):
        output_row_list = []
        for i in range(steps):
            output = image_patches[batch_size * idx : batch_size * (idx + 1)]

            if padding != 0:
                # Trim the overlap from every border that touches another tile.
                if j != 0:
                    output = output[..., padding:, :]
                if i != 0:
                    output = output[..., :, padding:]
                if j != steps - 1:
                    output = output[..., :-padding, :]
                if i != steps - 1:
                    output = output[..., :, :-padding]

            output_row_list.append(output)
            idx += 1

        output_row = torch.cat(output_row_list, dim=-1)
        output_list.append(output_row)
    output = torch.cat(output_list, dim=-2)
    return output
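

# A minimal round-trip sanity check for `split`/`merge` (hypothetical sizes;
# with zero overlap every pixel is kept exactly once, so merging must
# reproduce the input bit-for-bit).
if __name__ == "__main__":
    _x = torch.rand(2, 3, 768, 768)
    _patches = split(_x, overlap_ratio=0.0, patch_size=384)  # (8, 3, 384, 384)
    assert torch.equal(merge(_patches, batch_size=2, padding=0), _x)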