"""Contains Sliding Pyramid Network architecture. For licensing see accompanying LICENSE file. Copyright (C) 2025 Apple Inc. All Rights Reserved. """ from __future__ import annotations import math from typing import Iterable import torch import torch.fx import torch.nn as nn import torch.nn.functional as F from sharp.utils.training import checkpoint_wrapper from .base_encoder import BaseEncoder from .vit_encoder import TimmViT # torch.fx.wrap is used here to mark functions as leaf nodes during symbolic tracing # ensuring they are not traced but seen as atomic operation. In short, symbolic tracing # struggles with native python functions and conditional flows. non_traceable_ops = ("len", "int") for op in non_traceable_ops: torch.fx.wrap(op) class SlidingPyramidNetwork(BaseEncoder): """Sliding Pyramid Network. An encoder aimed at creating multi-resolution encodings from Vision Transformers. Reference: Bochkovskii et al. - "Depth pro: Sharp monocular metric depth in less than a second." (ICLR 2024) """ def __init__( self, dims_encoder: Iterable[int], patch_encoder: TimmViT, image_encoder: TimmViT, use_patch_overlap: bool = True, ): """Initialize Sliding Pyramid Network. The framework 1. creates an image pyramid, 2. generates overlapping patches with a sliding window at each pyramid level, 3. creates batched encodings via vision transformer backbones, 4. produces multi-resolution encodings. Args: dims_encoder: Dimensions of the encoder at different layers. patch_encoder: Backbone used for highres part of the pyramid. image_encoder: Backbone used for lowres part of the pyramid. use_patch_overlap: Whether to use overlap between patches in SPN. """ super().__init__() self.dim_in = patch_encoder.dim_in self.dims_encoder = list(dims_encoder) self.patch_encoder = patch_encoder self.image_encoder = image_encoder base_embed_dim = patch_encoder.embed_dim lowres_embed_dim = image_encoder.embed_dim self.patch_size = patch_encoder.internal_resolution() self.grad_checkpointing = False self.use_patch_overlap = use_patch_overlap # Retrieve intermediate feature ids registered in create_monodepth_encoder. self.patch_intermediate_features_ids = patch_encoder.intermediate_features_ids if ( not isinstance(self.patch_intermediate_features_ids, list) or not len(self.patch_intermediate_features_ids) == 4 ): raise ValueError("Patch intermediate feature ids must be a 4-item list.") self.image_intermediate_features_ids = image_encoder.intermediate_features_ids def _create_project_upsample_block( dim_in: int, dim_out: int, upsample_layers: int, dim_intermediate=None, ) -> nn.Module: if dim_intermediate is None: dim_intermediate = dim_out # Projection. blocks = [ nn.Conv2d( in_channels=dim_in, out_channels=dim_intermediate, kernel_size=1, stride=1, padding=0, bias=False, ) ] # Upsampling. 
            blocks += [
                nn.ConvTranspose2d(
                    in_channels=dim_intermediate if i == 0 else dim_out,
                    out_channels=dim_out,
                    kernel_size=2,
                    stride=2,
                    padding=0,
                    bias=False,
                )
                for i in range(upsample_layers)
            ]
            return nn.Sequential(*blocks)

        self.upsample_latent0 = _create_project_upsample_block(
            dim_in=base_embed_dim,
            dim_out=self.dims_encoder[0],
            upsample_layers=3,
            dim_intermediate=self.dims_encoder[1],
        )
        self.upsample_latent1 = _create_project_upsample_block(
            dim_in=base_embed_dim, dim_out=self.dims_encoder[1], upsample_layers=2
        )
        self.upsample0 = _create_project_upsample_block(
            dim_in=base_embed_dim, dim_out=self.dims_encoder[2], upsample_layers=1
        )
        self.upsample1 = _create_project_upsample_block(
            dim_in=base_embed_dim, dim_out=self.dims_encoder[3], upsample_layers=1
        )
        self.upsample2 = _create_project_upsample_block(
            dim_in=base_embed_dim, dim_out=self.dims_encoder[4], upsample_layers=1
        )

        self.upsample_lowres = nn.ConvTranspose2d(
            in_channels=lowres_embed_dim,
            out_channels=self.dims_encoder[4],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
        )
        self.fuse_lowres = nn.Conv2d(
            in_channels=(self.dims_encoder[4] + self.dims_encoder[4]),
            out_channels=self.dims_encoder[4],
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

    def internal_resolution(self) -> int:
        """Return the full image size of the SPN network."""
        return self.patch_size * 4

    @torch.jit.ignore
    def set_grad_checkpointing(self, is_enabled=True):
        """Enable grad checkpointing."""
        self.grad_checkpointing = is_enabled
        self.patch_encoder.set_grad_checkpointing(is_enabled)
        self.image_encoder.set_grad_checkpointing(is_enabled)

    @torch.jit.ignore
    def set_requires_grad_(self, patch_encoder: bool, image_encoder: bool):
        """Set requires grad for separate components."""
        self.patch_encoder.requires_grad_(patch_encoder)
        self.image_encoder.requires_grad_(image_encoder)

        # Always freeze the unused TimmViT head to exclude it from the calculation of
        # trainable parameters.
        self.patch_encoder.head.requires_grad_(False)
        self.image_encoder.head.requires_grad_(False)

        # These upsamplers only affect patch encoder's feature maps.
        self.upsample_latent0.requires_grad_(patch_encoder)
        self.upsample_latent1.requires_grad_(patch_encoder)
        self.upsample0.requires_grad_(patch_encoder)
        self.upsample1.requires_grad_(patch_encoder)
        self.upsample2.requires_grad_(patch_encoder)

        # This upsampler affects only image encoder's feature map.
        self.upsample_lowres.requires_grad_(image_encoder)

        # This fuser affects both image and patch encoders.
        self.fuse_lowres.requires_grad_(image_encoder or patch_encoder)

    def _create_pyramid(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Creates a 3-level image pyramid."""
        # Original resolution: 1536 by default.
        x0 = x
        # Middle resolution: 768 by default.
        x1 = F.interpolate(x, size=None, scale_factor=0.5, mode="bilinear", align_corners=False)
        # Low resolution: 384 by default, corresponding to the backbone resolution.
        x2 = F.interpolate(x, size=None, scale_factor=0.25, mode="bilinear", align_corners=False)
        return x0, x1, x2

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Encode input at multiple resolutions."""
        batch_size = x.shape[0]

        # Step 0: create a 3-level image pyramid.
        x0, x1, x2 = self._create_pyramid(x)

        if self.use_patch_overlap:
            # Step 1: split to create batched overlapped mini-images at the ViT
            # resolution.
            # 5x5 @ 384x384 at the highest resolution (1536x1536).
            x0_patches = split(x0, overlap_ratio=0.25, patch_size=self.patch_size)
            # 3x3 @ 384x384 at the middle resolution (768x768).
            x1_patches = split(x1, overlap_ratio=0.5, patch_size=self.patch_size)
            # 1x1 @ 384x384 at the lowest resolution (384x384).
            x2_patches = x2
            padding = 3
        else:
            # Step 1: split to create batched non-overlapping mini-images at the ViT
            # resolution.
            # 4x4 @ 384x384 at the highest resolution (1536x1536).
            x0_patches = split(x0, overlap_ratio=0.0, patch_size=self.patch_size)
            # 2x2 @ 384x384 at the middle resolution (768x768).
            x1_patches = split(x1, overlap_ratio=0.0, patch_size=self.patch_size)
            # 1x1 @ 384x384 at the lowest resolution (384x384).
            x2_patches = x2
            padding = 0

        # Number of tiles per image at the highest pyramid level (5x5=25 or 4x4=16).
        x0_tile_size = x0_patches.shape[0] // batch_size

        # Concatenate all the sliding-window patches to form a batch of size
        # (35 = 5x5 + 3x3 + 1x1) or (21 = 4x4 + 2x2 + 1x1) per input image.
        x_pyramid_patches = torch.cat(
            (x0_patches, x1_patches, x2_patches),
            dim=0,
        )

        # Step 2: run the ViT model on the large batch of patches.
        #
        # For the retrieval of intermediate features, forward hooks would be more
        # concise, but they are not well compatible with symbolic tracing: attributes
        # of submodules can be lost during tracing, so forward hooks may not be
        # preserved during graph transformation, leading to unexpected behavior.
        # Since they are not essential here, it is safer not to use them.
        x_pyramid_encodings, patch_intermediate_features = self.patch_encoder(x_pyramid_patches)

        # Step 3: merging.
        # Merge highres latent encodings.
        # NOTE: the list type check was completed in __init__.
        x_latent0_encodings = self.patch_encoder.reshape_feature(
            patch_intermediate_features[self.patch_intermediate_features_ids[0]]  # type:ignore[index]
        )
        x_latent0_features = merge(
            x_latent0_encodings[: batch_size * x0_tile_size],
            batch_size=batch_size,
            padding=padding,
        )

        x_latent1_encodings = self.patch_encoder.reshape_feature(
            patch_intermediate_features[self.patch_intermediate_features_ids[1]]  # type:ignore[index]
        )
        x_latent1_features = merge(
            x_latent1_encodings[: batch_size * x0_tile_size],
            batch_size=batch_size,
            padding=padding,
        )

        # Split the pyramid encoding batch (35 or 21) back into its
        # 5x5 + 3x3 + 1x1 (or 4x4 + 2x2 + 1x1) parts.
        x0_encodings, x1_encodings, x2_encodings = torch.split(
            x_pyramid_encodings,
            [len(x0_patches), len(x1_patches), len(x2_patches)],
            dim=0,
        )

        # 96x96 feature maps by merging 5x5 @ 24x24 patches with overlaps.
        x0_features = merge(x0_encodings, batch_size=batch_size, padding=padding)

        # 48x48 feature maps by merging 3x3 @ 24x24 patches with overlaps.
        x1_features = merge(x1_encodings, batch_size=batch_size, padding=2 * padding)

        # 24x24 feature maps.
        x2_features = x2_encodings

        # Apply the image encoder.
        x_lowres_features, image_intermediate_features = self.image_encoder(x2_patches)

        # Upsample feature maps.
        x_latent0_features = checkpoint_wrapper(self, self.upsample_latent0, x_latent0_features)
        x_latent1_features = checkpoint_wrapper(self, self.upsample_latent1, x_latent1_features)
        x0_features = checkpoint_wrapper(self, self.upsample0, x0_features)
        x1_features = checkpoint_wrapper(self, self.upsample1, x1_features)
        x2_features = checkpoint_wrapper(self, self.upsample2, x2_features)

        x_lowres_features = checkpoint_wrapper(self, self.upsample_lowres, x_lowres_features)
        x_lowres_features = checkpoint_wrapper(
            self, self.fuse_lowres, torch.cat((x2_features, x_lowres_features), dim=1)
        )

        output = [
            x_latent0_features,
            x_latent1_features,
            x0_features,
            x1_features,
            x_lowres_features,
        ]
        return output
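

# A minimal sketch (not part of the network) of the feature-map resolutions that
# forward() produces for the default 1536px input, assuming a patch encoder that
# yields 24x24 feature tokens per 384px patch (e.g. a ViT with 16px patch size):
#
#     merge(5x5 @ 24x24, padding=3)  -> 96x96  (highres latents and x0)
#     merge(3x3 @ 24x24, padding=6)  -> 48x48  (x1)
#     image encoder                  -> 24x24  (x2 and lowres)
#
# The projection/upsample blocks then scale these by 8x, 4x, 2x, 2x and 2x
# respectively, so the five outputs land at 768, 384, 192, 96 and 48 pixels,
# i.e. 1/2 down to 1/32 of the input resolution.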
# torch.fx.wrap can only be applied to functions, not methods. Hence, split and
# merge were converted into module-level functions so they can be marked as
# atomic operations for symbolic tracing.
@torch.fx.wrap
def split(image: torch.Tensor, overlap_ratio: float = 0.25, patch_size: int = 384) -> torch.Tensor:
    """Split the input into small patches with a sliding window."""
    patch_stride = int(patch_size * (1 - overlap_ratio))

    image_size = image.shape[-1]
    steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1

    x_patch_list = []
    for j in range(steps):
        j0 = j * patch_stride
        j1 = j0 + patch_size
        for i in range(steps):
            i0 = i * patch_stride
            i1 = i0 + patch_size
            x_patch_list.append(image[..., j0:j1, i0:i1])

    return torch.cat(x_patch_list, dim=0)


# Decorator marking the function as an atomic operation for symbolic tracing.
@torch.fx.wrap
def merge(image_patches: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
    """Merge the patched input into an image with a sliding window."""
    steps = int(math.sqrt(image_patches.shape[0] // batch_size))

    idx = 0
    output_list = []
    for j in range(steps):
        output_row_list = []
        for i in range(steps):
            output = image_patches[batch_size * idx : batch_size * (idx + 1)]

            # Crop the overlapping interior borders so that neighboring patches
            # tile seamlessly; boundary patches keep their outer edges.
            if padding != 0:
                if j != 0:
                    output = output[..., padding:, :]
                if i != 0:
                    output = output[..., :, padding:]
                if j != steps - 1:
                    output = output[..., :-padding, :]
                if i != steps - 1:
                    output = output[..., :, :-padding]

            output_row_list.append(output)
            idx += 1

        output_row = torch.cat(output_row_list, dim=-1)
        output_list.append(output_row)

    output = torch.cat(output_list, dim=-2)
    return output
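

if __name__ == "__main__":
    # A minimal, self-contained sanity check (not part of the library) that
    # illustrates the split/merge geometry on random data, assuming the default
    # 384px patch size used by the network.
    x = torch.randn(2, 3, 768, 768)

    # Non-overlapping split: 768 / 384 = 2 steps per axis, i.e. 2x2 = 4 patches
    # per image, batched along dim 0.
    patches = split(x, overlap_ratio=0.0, patch_size=384)
    assert patches.shape == (2 * 4, 3, 384, 384)

    # With no overlap there is nothing to crop away, so merging with padding=0
    # reconstructs the input exactly.
    assert torch.equal(merge(patches, batch_size=2, padding=0), x)

    # Overlapping split at ratio 0.5: stride 192, hence
    # ceil((768 - 384) / 192) + 1 = 3 steps per axis, i.e. 3x3 = 9 patches.
    patches = split(x, overlap_ratio=0.5, patch_size=384)
    assert patches.shape == (2 * 9, 3, 384, 384)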