|
|
|
|
|
|
| from __future__ import annotations
|
|
|
| import math
|
| from typing import Iterable, Optional
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
|
|
class DepthProEncoder(nn.Module):
    """DepthPro Encoder.

    An encoder aimed at creating multi-resolution encodings from Vision Transformers.

    One ``patch_encoder`` pass processes every sliding-window patch of a
    three-level image pyramid. A separate ``image_encoder`` pass encodes the
    lowest pyramid level as a whole image. The sliding-window size and the
    per-edge trim used when stitching patches back together are derived from
    the patch encoder's ``patch_embed`` geometry.
    """

    def __init__(
        self,
        dims_encoder: Iterable[int],
        patch_encoder: nn.Module,
        image_encoder: nn.Module,
        hook_block_ids: Iterable[int],
        decoder_features: int,
    ):
        """Initialize DepthProEncoder.

        The framework
            1. creates an image pyramid,
            2. generates overlapping patches with a sliding window at each pyramid level,
            3. creates batched encodings via vision transformer backbones,
            4. produces multi-resolution encodings.

        Args:
        ----
            dims_encoder: Dimensions of the encoder at different layers. At least
                four entries are required (indices 0-3 are used).
            patch_encoder: Backbone used for patches. Must expose ``embed_dim``,
                ``patch_embed.img_size``, ``patch_embed.patch_size`` and an
                indexable ``blocks``.
            image_encoder: Backbone used for global image encoder. Must expose
                ``embed_dim``.
            hook_block_ids: Indices into ``patch_encoder.blocks``. Forward hooks
                on the first two capture intermediate (latent) features.
            decoder_features: Number of feature output in the decoder.

        """
        super().__init__()

        self.dims_encoder = list(dims_encoder)
        self.patch_encoder = patch_encoder
        self.image_encoder = image_encoder
        self.hook_block_ids = list(hook_block_ids)

        patch_encoder_embed_dim = patch_encoder.embed_dim
        image_encoder_embed_dim = image_encoder.embed_dim

        # Side length, in tokens, of the square feature grid for one patch.
        self.out_size = int(
            patch_encoder.patch_embed.img_size[0] // patch_encoder.patch_embed.patch_size[0]
        )

        def _create_project_upsample_block(
            dim_in: int,
            dim_out: int,
            upsample_layers: int,
            dim_int: Optional[int] = None,
        ) -> nn.Module:
            # 1x1 projection to ``dim_int`` followed by ``upsample_layers``
            # stride-2 transposed convolutions (each doubles H and W).
            if dim_int is None:
                dim_int = dim_out

            blocks = [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=dim_int,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False,
                )
            ]
            blocks += [
                nn.ConvTranspose2d(
                    in_channels=dim_int if i == 0 else dim_out,
                    out_channels=dim_out,
                    kernel_size=2,
                    stride=2,
                    padding=0,
                    bias=False,
                )
                for i in range(upsample_layers)
            ]
            return nn.Sequential(*blocks)

        self.upsample_latent0 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim,
            dim_int=self.dims_encoder[0],
            dim_out=decoder_features,
            upsample_layers=3,
        )
        self.upsample_latent1 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[0], upsample_layers=2
        )

        self.upsample0 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[1], upsample_layers=1
        )
        self.upsample1 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[2], upsample_layers=1
        )
        self.upsample2 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[3], upsample_layers=1
        )

        self.upsample_lowres = nn.ConvTranspose2d(
            in_channels=image_encoder_embed_dim,
            out_channels=self.dims_encoder[3],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
        )
        # Fuses the upsampled quarter-resolution patch features with the
        # upsampled global features (concatenated along channels).
        self.fuse_lowres = nn.Conv2d(
            in_channels=(self.dims_encoder[3] + self.dims_encoder[3]),
            out_channels=self.dims_encoder[3],
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        # Overwritten by the forward hooks below whenever the hooked blocks run.
        self.backbone_highres_hook0 = None
        self.backbone_highres_hook1 = None
        self.patch_encoder.blocks[self.hook_block_ids[0]].register_forward_hook(
            self._hook0
        )
        self.patch_encoder.blocks[self.hook_block_ids[1]].register_forward_hook(
            self._hook1
        )

    def _hook0(self, model, input, output):
        """Forward hook: store the output of the first hooked patch-encoder block."""
        self.backbone_highres_hook0 = output

    def _hook1(self, model, input, output):
        """Forward hook: store the output of the second hooked patch-encoder block."""
        self.backbone_highres_hook1 = output

    @property
    def img_size(self) -> int:
        """Return the full input resolution expected by the encoder.

        This is four times the patch encoder's input size. For an input of this
        size, the 1/4-resolution pyramid level built by ``_create_pyramid``
        matches the patch encoder's input size.
        """
        return self.patch_encoder.patch_embed.img_size[0] * 4

    def _create_pyramid(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create a 3-level image pyramid.

        Returns the input unchanged plus bilinear downsamples at 1/2 and 1/4
        resolution.
        """
        full_res = x
        half_res = F.interpolate(
            x, size=None, scale_factor=0.5, mode="bilinear", align_corners=False
        )
        quarter_res = F.interpolate(
            x, size=None, scale_factor=0.25, mode="bilinear", align_corners=False
        )
        return full_res, half_res, quarter_res

    def split(
        self,
        x: torch.Tensor,
        overlap_ratio: float = 0.25,
        patch_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Split the input into small patches with sliding window.

        Square windows of ``patch_size`` pixels are taken with stride
        ``int(patch_size * (1 - overlap_ratio))`` along both spatial axes. They
        are stacked along dim 0 in row-major window order: all batch entries
        for window (0, 0), then window (0, 1), and so on. This is the order
        ``merge`` expects. The step count is computed from the last (width)
        dimension only, so the input is assumed to be square.

        Args:
        ----
            x: Input tensor whose last two dims are spatial (H, W).
            overlap_ratio: Fraction of a window shared with its neighbour.
            patch_size: Window side in pixels. Defaults to the patch encoder's
                input size, ``patch_encoder.patch_embed.img_size[0]``.

        Returns:
        -------
            The windows concatenated along dim 0.

        Raises:
        ------
            ValueError: If ``overlap_ratio`` leaves a non-positive stride.

        """
        if patch_size is None:
            patch_size = self.patch_encoder.patch_embed.img_size[0]
        patch_stride = int(patch_size * (1 - overlap_ratio))
        if patch_stride <= 0:
            raise ValueError(
                f"overlap_ratio={overlap_ratio} leaves no stride for a "
                f"{patch_size}-pixel window"
            )

        image_size = x.shape[-1]
        steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1
        offsets = [step * patch_stride for step in range(steps)]

        # Outer loop over rows (j0), inner over columns (i0): row-major order.
        x_patch_list = [
            x[..., j0 : j0 + patch_size, i0 : i0 + patch_size]
            for j0 in offsets
            for i0 in offsets
        ]
        return torch.cat(x_patch_list, dim=0)

    def merge(self, x: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
        """Merge the patched input into a image with sliding window.

        Inverse of ``split``. ``x`` holds ``steps * steps`` consecutive groups
        of ``batch_size`` patches in row-major window order. ``padding``
        rows/columns are trimmed from every patch edge that touches a
        neighbouring patch; outer image borders are kept. The trimmed tiles are
        then concatenated back into one map per batch entry.

        Args:
        ----
            x: Patches stacked along dim 0; the last two dims are spatial.
            batch_size: Number of batch entries the patches were split from.
            padding: Cells trimmed from each interior edge; ``0`` trims nothing.

        Returns:
        -------
            The stitched tensor with ``batch_size`` entries along dim 0.

        Raises:
        ------
            ValueError: If ``x.shape[0]`` is not ``batch_size`` times a
                perfect square.

        """
        num_windows, remainder = divmod(x.shape[0], batch_size)
        steps = math.isqrt(num_windows)
        if remainder or steps * steps != num_windows:
            raise ValueError(
                f"cannot merge {x.shape[0]} patches with batch_size={batch_size} "
                "into a square grid"
            )

        output_list = []
        for j in range(steps):
            output_row_list = []
            for i in range(steps):
                idx = j * steps + i
                tile = x[batch_size * idx : batch_size * (idx + 1)]
                height, width = tile.shape[-2], tile.shape[-1]

                # Explicit end indices: a negative slice ``[:-padding]`` would
                # be the empty slice ``[:0]`` when ``padding == 0``.
                top = padding if j != 0 else 0
                bottom = height - padding if j != steps - 1 else height
                left = padding if i != 0 else 0
                right = width - padding if i != steps - 1 else width

                output_row_list.append(tile[..., top:bottom, left:right])
            output_list.append(torch.cat(output_row_list, dim=-1))
        return torch.cat(output_list, dim=-2)

    def _merge_padding(self, overlap_ratio: float) -> int:
        """Return the per-edge trim that undoes a ``split`` with ``overlap_ratio``.

        Adjacent windows from ``split`` share ``window - stride`` pixels, i.e.
        ``(window - stride) // token`` feature cells. ``merge`` removes half of
        that from each of the two touching edges. For example, a 384-pixel
        window with 16-pixel tokens gives 3 for ``overlap_ratio=0.25`` and 6
        for ``overlap_ratio=0.5``.
        """
        window = self.patch_encoder.patch_embed.img_size[0]
        token = self.patch_encoder.patch_embed.patch_size[0]
        stride = int(window * (1 - overlap_ratio))
        return (window - stride) // token // 2

    def reshape_feature(
        self, embeddings: torch.Tensor, width, height, cls_token_offset=1
    ):
        """Discard class token and reshape 1D feature map to a 2D grid.

        Args:
        ----
            embeddings: Token sequence of shape ``(B, tokens, C)``.
            width: Grid width in tokens.
            height: Grid height in tokens.
            cls_token_offset: Number of leading tokens dropped before reshaping.

        Returns:
        -------
            Tensor of shape ``(B, C, height, width)``.

        """
        b, _, c = embeddings.shape

        if cls_token_offset > 0:
            embeddings = embeddings[:, cls_token_offset:, :]

        return embeddings.reshape(b, height, width, c).permute(0, 3, 1, 2)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Encode input at multiple resolutions.

        Args:
        ----
            x (torch.Tensor): Input image.

        Returns:
        -------
            Multi resolution encoded features, in this order:
            the two hooked latent maps (full-resolution patches only), the
            merged full- and half-resolution patch features, and the
            quarter-resolution patch features fused with the global
            ``image_encoder`` features.

        """
        batch_size = x.shape[0]

        # Step 0: 3-level pyramid (full, 1/2 and 1/4 resolution).
        x0, x1, x2 = self._create_pyramid(x)

        # Step 1: overlapping sliding-window patches for the two finer levels.
        # The coarsest level is used whole.
        x0_overlap = 0.25
        x1_overlap = 0.5
        x0_patches = self.split(x0, overlap_ratio=x0_overlap)
        x1_patches = self.split(x1, overlap_ratio=x1_overlap)
        x2_patches = x2

        x_pyramid_patches = torch.cat(
            (x0_patches, x1_patches, x2_patches),
            dim=0,
        )

        # Step 2: one batched patch-encoder pass over every patch. The forward
        # hooks registered in __init__ store the hooked blocks' outputs.
        x_pyramid_encodings = self.patch_encoder(x_pyramid_patches)
        x_pyramid_encodings = self.reshape_feature(
            x_pyramid_encodings, self.out_size, self.out_size
        )

        # Step 3: latent features, restricted to the full-resolution patches
        # (the first ``len(x0_patches)`` entries of the batched pass).
        num_x0_patches = x0_patches.shape[0]
        x0_padding = self._merge_padding(x0_overlap)
        x1_padding = self._merge_padding(x1_overlap)

        x_latent0_encodings = self.reshape_feature(
            self.backbone_highres_hook0,
            self.out_size,
            self.out_size,
        )
        x_latent0_features = self.merge(
            x_latent0_encodings[:num_x0_patches], batch_size=batch_size, padding=x0_padding
        )

        x_latent1_encodings = self.reshape_feature(
            self.backbone_highres_hook1,
            self.out_size,
            self.out_size,
        )
        x_latent1_features = self.merge(
            x_latent1_encodings[:num_x0_patches], batch_size=batch_size, padding=x0_padding
        )

        # Step 4: separate the batched encodings back into pyramid levels and
        # stitch each patched level back into a single map.
        x0_encodings, x1_encodings, x2_encodings = torch.split(
            x_pyramid_encodings,
            [len(x0_patches), len(x1_patches), len(x2_patches)],
            dim=0,
        )
        x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=x0_padding)
        x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=x1_padding)
        x2_features = x2_encodings

        # Step 5: global features of the coarsest level from the image encoder.
        x_global_features = self.image_encoder(x2_patches)
        x_global_features = self.reshape_feature(
            x_global_features, self.out_size, self.out_size
        )

        # Step 6: project and upsample every branch.
        x_latent0_features = self.upsample_latent0(x_latent0_features)
        x_latent1_features = self.upsample_latent1(x_latent1_features)

        x0_features = self.upsample0(x0_features)
        x1_features = self.upsample1(x1_features)
        x2_features = self.upsample2(x2_features)

        x_global_features = self.upsample_lowres(x_global_features)
        x_global_features = self.fuse_lowres(
            torch.cat((x2_features, x_global_features), dim=1)
        )

        return [
            x_latent0_features,
            x_latent1_features,
            x0_features,
            x1_features,
            x_global_features,
        ]
|
|
|