# Copyright (C) 2026 Li Auto Inc. All Rights Reserved.
"""MetricAnything DepthMap encoder."""
from __future__ import annotations
import math
from typing import Iterable
import torch
import torch.nn as nn
import torch.nn.functional as F
class MetricAnythingEncoder(nn.Module):
"""Multi-resolution encoder using a ViT patch backbone."""
def __init__(
self,
dims_encoder: Iterable[int],
patch_encoder: nn.Module,
hook_block_ids: Iterable[int],
) -> None:
super().__init__()
self.dims_encoder = list(dims_encoder)
self.patch_encoder = patch_encoder
self.hook_block_ids = list(hook_block_ids)
embed_dim = patch_encoder.embed_dim
patch_size = patch_encoder.patch_embed.patch_size[0]
self.out_size = int(patch_encoder.patch_embed.img_size[0] // patch_size)
self.upsample_latent0 = self._project_upsample(embed_dim, self.dims_encoder[0], upsample_layers=1)
self.upsample_latent1 = self._project_upsample(embed_dim, self.dims_encoder[0], upsample_layers=2)
self.upsample_latent2 = self._project_upsample(embed_dim, self.dims_encoder[0], upsample_layers=2)
self.upsample_latent3 = self._project_upsample(embed_dim, self.dims_encoder[0], upsample_layers=2)
self.upsample0 = self._project_upsample(embed_dim, self.dims_encoder[1], upsample_layers=3)
self.upsample1 = self._project_upsample(embed_dim, self.dims_encoder[2], upsample_layers=1)
self.upsample2 = self._project_upsample(embed_dim, self.dims_encoder[3], upsample_layers=1)
self.upsample_lowres = nn.ConvTranspose2d(
in_channels=embed_dim,
out_channels=self.dims_encoder[3],
kernel_size=2,
stride=2,
padding=0,
bias=True,
)
self.fuse_lowres = nn.Conv2d(
in_channels=self.dims_encoder[3] * 2,
out_channels=self.dims_encoder[3],
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
self.patch_encoder.blocks[self.hook_block_ids[0]].register_forward_hook(self._hook0)
self.patch_encoder.blocks[self.hook_block_ids[1]].register_forward_hook(self._hook1)
self.patch_encoder.blocks[self.hook_block_ids[2]].register_forward_hook(self._hook2)
self.patch_encoder.blocks[self.hook_block_ids[3]].register_forward_hook(self._hook3)
@staticmethod
def _project_upsample(
dim_in: int,
dim_out: int,
upsample_layers: int,
dim_int: int | None = None,
) -> nn.Sequential:
if dim_int is None:
dim_int = dim_out
layers: list[nn.Module] = [
nn.Conv2d(
in_channels=dim_in,
out_channels=dim_int,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
]
layers += [
nn.ConvTranspose2d(
in_channels=dim_int if i == 0 else dim_out,
out_channels=dim_out,
kernel_size=2,
stride=2,
padding=0,
bias=False,
)
for i in range(upsample_layers)
]
return nn.Sequential(*layers)
def _hook0(self, _module, _input, output) -> None:
self.backbone_highres_hook0 = output
def _hook1(self, _module, _input, output) -> None:
self.backbone_highres_hook1 = output
def _hook2(self, _module, _input, output) -> None:
self.backbone_highres_hook2 = output
def _hook3(self, _module, _input, output) -> None:
self.backbone_highres_hook3 = output
@property
def img_size(self) -> int:
"""Network input resolution (typically 1536)."""
return self.patch_encoder.patch_embed.img_size[0] * 4
def _create_pyramid(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Create a 3-level image pyramid."""
x0 = x
x1 = F.interpolate(x, size=None, scale_factor=0.5, mode="bilinear", align_corners=False)
x2 = F.interpolate(x, size=None, scale_factor=0.25, mode="bilinear", align_corners=False)
return x0, x1, x2
def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor:
"""Split the input into overlapped 384x384 patches."""
patch_size = 384
patch_stride = int(patch_size * (1 - overlap_ratio))
image_size = x.shape[-1]
steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1
patches = []
for j in range(steps):
j0 = j * patch_stride
j1 = j0 + patch_size
for i in range(steps):
i0 = i * patch_stride
i1 = i0 + patch_size
patches.append(x[..., j0:j1, i0:i1])
return torch.cat(patches, dim=0)
def merge(self, x: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
"""Merge overlapped patches back to a full feature map."""
steps = int(math.sqrt(x.shape[0] // batch_size))
idx = 0
rows = []
for j in range(steps):
cols = []
for i in range(steps):
patch = x[batch_size * idx : batch_size * (idx + 1)]
if j != 0:
patch = patch[..., padding:, :]
if i != 0:
patch = patch[..., :, padding:]
if j != steps - 1:
patch = patch[..., :-padding, :]
if i != steps - 1:
patch = patch[..., :, :-padding]
cols.append(patch)
idx += 1
rows.append(torch.cat(cols, dim=-1))
return torch.cat(rows, dim=-2)
@staticmethod
def reshape_feature(embeddings: torch.Tensor, width: int, height: int, cls_token_offset: int = 1) -> torch.Tensor:
"""Discard class token and reshape 1D tokens to a 2D feature map."""
batch, tokens, channels = embeddings.shape
if cls_token_offset > 0:
embeddings = embeddings[:, cls_token_offset:, :]
return embeddings.reshape(batch, height, width, channels).permute(0, 3, 1, 2)
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
"""Encode input at multiple resolutions."""
batch_size = x.shape[0]
x0, x1, x2 = self._create_pyramid(x)
x0_patches = self.split(x0, overlap_ratio=0.25)
x1_patches = self.split(x1, overlap_ratio=0.5)
x2_patches = x2
x_pyramid_patches = torch.cat((x0_patches, x1_patches, x2_patches), dim=0)
x_pyramid_encodings = self.patch_encoder(x_pyramid_patches)
x_pyramid_encodings = self.reshape_feature(
x_pyramid_encodings, self.out_size, self.out_size, cls_token_offset=1
)
if isinstance(self.backbone_highres_hook0, list):
self.backbone_highres_hook0 = self.backbone_highres_hook0[0]
self.backbone_highres_hook1 = self.backbone_highres_hook1[0]
self.backbone_highres_hook2 = self.backbone_highres_hook2[0]
self.backbone_highres_hook3 = self.backbone_highres_hook3[0]
high_patch_count = x0_patches.shape[0]
x_latent0_features = self.merge(
self.reshape_feature(
self.backbone_highres_hook0,
self.out_size,
self.out_size,
cls_token_offset=5,
)[:high_patch_count],
batch_size=batch_size,
padding=3,
)
x_latent1_features = self.merge(
self.reshape_feature(
self.backbone_highres_hook1,
self.out_size,
self.out_size,
cls_token_offset=5,
)[:high_patch_count],
batch_size=batch_size,
padding=3,
)
x_latent2_features = self.merge(
self.reshape_feature(
self.backbone_highres_hook2,
self.out_size,
self.out_size,
cls_token_offset=5,
)[:high_patch_count],
batch_size=batch_size,
padding=3,
)
x_latent3_features = self.merge(
self.reshape_feature(
self.backbone_highres_hook3,
self.out_size,
self.out_size,
cls_token_offset=5,
)[:high_patch_count],
batch_size=batch_size,
padding=3,
)
x0_encodings, x1_encodings, x2_encodings = torch.split(
x_pyramid_encodings,
[len(x0_patches), len(x1_patches), len(x2_patches)],
dim=0,
)
x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=3)
x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=6)
x2_features = x2_encodings
x_global_features = x2_features.clone()
x_latent0_features = self.upsample_latent0(x_latent0_features)
x_latent1_features = self.upsample_latent1(x_latent1_features)
x_latent2_features = self.upsample_latent2(x_latent2_features)
x_latent3_features = self.upsample_latent3(x_latent3_features)
x0_features = self.upsample0(x0_features)
x1_features = self.upsample1(x1_features)
x2_features = self.upsample2(x2_features)
x_global_features = self.upsample_lowres(x_global_features)
x_global_features = self.fuse_lowres(torch.cat((x2_features, x_global_features), dim=1))
return [
x_latent0_features,
x_latent1_features,
x_latent2_features,
x_latent3_features,
x0_features,
x1_features,
x_global_features,
]
__all__ = ["MetricAnythingEncoder"]