# Copyright (C) 2026 Li Auto Inc. All Rights Reserved.
"""Decoder blocks for MetricAnything DepthMap."""
from __future__ import annotations
from typing import Iterable
import torch
from torch import nn
class MultiresConvDecoder(nn.Module):
"""Fuse multi-resolution encoder features."""
def __init__(self, dims_encoder: Iterable[int], dim_decoder: int) -> None:
super().__init__()
self.dims_encoder = list(dims_encoder)
self.dim_decoder = dim_decoder
num_encoders = len(self.dims_encoder)
in_dims = (
[self.dims_encoder[-3]]
+ [self.dims_encoder[-4]] * 4
+ [self.dims_encoder[-2], self.dims_encoder[-1]]
)
self.convs = nn.ModuleList(
[
nn.Conv2d(dim_in, dim_decoder, kernel_size=3, stride=1, padding=1, bias=False)
for dim_in in in_dims
]
)
deconv_flags = [False, True, False, False, True, True, True]
self.fusions = nn.ModuleList(
[
FeatureFusionBlock2d(
num_features=dim_decoder,
deconv=deconv_flags[i],
batch_norm=False,
disable_resnet1=(i == num_encoders - 1),
)
for i in range(num_encoders)
]
)
def forward(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Decode the multi-resolution encodings."""
num_levels = len(encodings)
num_encoders = len(self.dims_encoder)
if num_levels != num_encoders:
raise ValueError(
f"Got encoder output levels={num_levels}, expected levels={num_encoders}."
)
encodings_forward_ids = [4, 3, 2, 1, 0, 5, 6]
features = self.convs[-1](encodings[-1])
lowres_features = features
features = self.fusions[-1](features, None)
for i in range(num_levels - 2, -1, -1):
features_i = self.convs[i](encodings[encodings_forward_ids[i]])
features = self.fusions[i](features, features_i)
return features, lowres_features
class ResidualBlock(nn.Module):
"""Generic residual block (He et al., 2016)."""
def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
super().__init__()
self.residual = residual
self.shortcut = shortcut
def forward(self, x: torch.Tensor) -> torch.Tensor:
delta_x = self.residual(x)
if self.shortcut is not None:
x = self.shortcut(x)
return x + delta_x
class FeatureFusionBlock2d(nn.Module):
"""Feature fusion with residual refinement and optional upsampling."""
def __init__(
self,
num_features: int,
deconv: bool = False,
batch_norm: bool = False,
disable_resnet1: bool = False,
) -> None:
super().__init__()
self.resnet1 = nn.Identity() if disable_resnet1 else self._residual_block(num_features, batch_norm)
self.resnet2 = self._residual_block(num_features, batch_norm)
self.use_deconv = deconv
if deconv:
self.deconv = nn.ConvTranspose2d(
in_channels=num_features,
out_channels=num_features,
kernel_size=2,
stride=2,
padding=0,
bias=False,
)
self.out_conv = nn.Conv2d(
num_features,
num_features,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
x = x0
if x1 is not None:
x1_res = self.resnet1(x1)
x = self.skip_add.add(x, x1_res)
x = self.resnet2(x)
if self.use_deconv:
x = self.deconv(x)
return self.out_conv(x)
@staticmethod
def _residual_block(num_features: int, batch_norm: bool) -> ResidualBlock:
def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
layers: list[nn.Module] = [
nn.ReLU(False),
nn.Conv2d(
dim,
dim,
kernel_size=3,
stride=1,
padding=1,
bias=not batch_norm,
),
]
if batch_norm:
layers.append(nn.BatchNorm2d(dim))
return layers
residual = nn.Sequential(
*_create_block(dim=num_features, batch_norm=batch_norm),
*_create_block(dim=num_features, batch_norm=batch_norm),
)
return ResidualBlock(residual)
__all__ = ["MultiresConvDecoder", "FeatureFusionBlock2d", "ResidualBlock"]