from typing import List, Tuple
import torch
import torch.nn as nn
from uniception.models.encoders import ViTEncoderInput
from uniception.models.encoders.croco import CroCoEncoder
from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT
from uniception.models.info_sharing.base import MultiViewTransformerInput
from uniception.models.info_sharing.cross_attention_transformer import (
MultiViewCrossAttentionTransformer,
MultiViewCrossAttentionTransformerIFR,
)
from uniception.models.libs.croco.pos_embed import RoPE2D, get_2d_sincos_pos_embed
from uniception.models.prediction_heads.adaptors import PointMapWithConfidenceAdaptor
from uniception.models.prediction_heads.base import AdaptorInput, PredictionHeadInput, PredictionHeadLayeredInput
from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
from uniception.models.prediction_heads.linear import LinearFeature
def is_symmetrized(gt1, gt2):
"Function to check if input pairs are symmetrized, i.e., (a, b) and (b, a) always exist in the input"
x = gt1["instance"]
y = gt2["instance"]
if len(x) == len(y) and len(x) == 1:
return False # special case of batchsize 1
ok = True
for i in range(0, len(x), 2):
ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i])
return ok
def interleave(tensor1, tensor2):
"Interleave two tensors along the first dimension (used to avoid redundant encoding for symmetrized pairs)"
res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
return res1, res2
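
# Illustration (hypothetical tensors, not part of the original code): for a
# symmetrized batch where view1 holds images [a, b] and view2 holds [b, a],
# is_symmetrized() returns True, so only the unique pair (a, b) is encoded once
# and interleave() rebuilds the full batch ordering of the features, e.g.:
#   >>> f_a = torch.randn(1, 768, 14, 14); f_b = torch.randn(1, 768, 14, 14)
#   >>> r1, r2 = interleave(f_a, f_b)
#   >>> r1.shape  # (2, 768, 14, 14), ordered [f_a, f_b]; r2 is ordered [f_b, f_a]
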
class DUSt3R(nn.Module):
"DUSt3R defined with UniCeption Modules"
def __init__(
self,
name: str,
data_norm_type: str = "dust3r",
img_size: tuple = (224, 224),
patch_embed_cls: str = "PatchEmbedDust3R",
pred_head_type: str = "linear",
pred_head_output_dim: int = 4,
pred_head_feature_dim: int = 256,
depth_mode: Tuple[str, float, float] = ("exp", -float("inf"), float("inf")),
conf_mode: Tuple[str, float, float] = ("exp", 1, float("inf")),
pos_embed: str = "RoPE100",
pretrained_checkpoint_path: str = None,
pretrained_encoder_checkpoint_path: str = None,
pretrained_info_sharing_checkpoint_path: str = None,
pretrained_pred_head_checkpoint_paths: List[str] = [None, None],
pretrained_pred_head_regressor_checkpoint_paths: List[str] = [None, None],
override_encoder_checkpoint_attributes: bool = False,
*args,
**kwargs,
):
"""
Two-view model containing siamese encoders followed by a two-view cross-attention transformer and respective downstream heads.
        The goal is to output the scene representation directly, with both views expressed in view1's frame (hence the asymmetry).
Args:
name (str): Name of the model.
data_norm_type (str): Type of data normalization. (default: "dust3r")
img_size (tuple): Size of input images. (default: (224, 224))
patch_embed_cls (str): Class for patch embedding. (default: "PatchEmbedDust3R"). Options:
- "PatchEmbedDust3R"
- "ManyAR_PatchEmbed"
pred_head_type (str): Type of prediction head. (default: "linear"). Options:
- "linear"
- "dpt"
pred_head_output_dim (int): Output dimension of prediction head. (default: 4)
pred_head_feature_dim (int): Feature dimension of prediction head. (default: 256)
depth_mode (Tuple[str, float, float]): Depth mode settings (mode=['linear', 'square', 'exp'], vmin, vmax). (default: ('exp', -inf, inf))
conf_mode (Tuple[str, float, float]): Confidence mode settings (mode=['linear', 'square', 'exp'], vmin, vmax). (default: ('exp', 1, inf))
pos_embed (str): Position embedding type. (default: 'RoPE100')
pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None)
pretrained_encoder_checkpoint_path (str): Path to pretrained encoder checkpoint. (default: None)
pretrained_info_sharing_checkpoint_path (str): Path to pretrained info_sharing checkpoint. (default: None)
            pretrained_pred_head_checkpoint_paths (List[str]): Paths to pretrained prediction head checkpoints. (default: [None, None])
            pretrained_pred_head_regressor_checkpoint_paths (List[str]): Paths to pretrained prediction head regressor checkpoints. (default: [None, None])
override_encoder_checkpoint_attributes (bool): Whether to override encoder checkpoint attributes. (default: False)
"""
super().__init__(*args, **kwargs)
        # Initialize the attributes
self.name = name
self.data_norm_type = data_norm_type
self.img_size = img_size
self.patch_embed_cls = patch_embed_cls
self.pred_head_type = pred_head_type
self.pred_head_output_dim = pred_head_output_dim
self.depth_mode = depth_mode
self.conf_mode = conf_mode
self.pos_embed = pos_embed
self.pretrained_checkpoint_path = pretrained_checkpoint_path
self.pretrained_encoder_checkpoint_path = pretrained_encoder_checkpoint_path
self.pretrained_info_sharing_checkpoint_path = pretrained_info_sharing_checkpoint_path
self.pretrained_pred_head_checkpoint_paths = pretrained_pred_head_checkpoint_paths
self.pretrained_pred_head_regressor_checkpoint_paths = pretrained_pred_head_regressor_checkpoint_paths
self.override_encoder_checkpoint_attributes = override_encoder_checkpoint_attributes
# Initialize RoPE for the CroCo Encoder & Two-View Cross Attention Transformer
        freq = float(pos_embed[len("RoPE") :])  # e.g., "RoPE100" -> freq = 100.0
self.rope = RoPE2D(freq=freq)
# Initialize Encoder
self.encoder = CroCoEncoder(
name=name,
data_norm_type=data_norm_type,
patch_embed_cls=patch_embed_cls,
img_size=img_size,
pretrained_checkpoint_path=pretrained_encoder_checkpoint_path,
override_checkpoint_attributes=override_encoder_checkpoint_attributes,
)
# Initialize Multi-View Cross Attention Transformer
if self.pred_head_type == "linear":
# Returns only normalized last layer features
self.info_sharing = MultiViewCrossAttentionTransformer(
name="base_info_sharing",
input_embed_dim=self.encoder.enc_embed_dim,
num_views=2,
custom_positional_encoding=self.rope,
pretrained_checkpoint_path=pretrained_info_sharing_checkpoint_path,
)
elif self.pred_head_type == "dpt":
# Returns intermediate features and normalized last layer features
self.info_sharing = MultiViewCrossAttentionTransformerIFR(
name="base_info_sharing",
input_embed_dim=self.encoder.enc_embed_dim,
num_views=2,
indices=[5, 8],
norm_intermediate=False,
custom_positional_encoding=self.rope,
pretrained_checkpoint_path=pretrained_info_sharing_checkpoint_path,
)
else:
raise ValueError(f"Invalid prediction head type: {pred_head_type}. Must be 'linear' or 'dpt'.")
# Initialize Prediction Heads
if pred_head_type == "linear":
# Initialize Prediction Head 1
self.head1 = LinearFeature(
input_feature_dim=self.info_sharing.dim,
output_dim=pred_head_output_dim,
patch_size=self.encoder.patch_size,
pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[0],
)
# Initialize Prediction Head 2
self.head2 = LinearFeature(
input_feature_dim=self.info_sharing.dim,
output_dim=pred_head_output_dim,
patch_size=self.encoder.patch_size,
pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[1],
)
elif pred_head_type == "dpt":
            # Initialize Prediction Head 1
self.dpt_feature_head1 = DPTFeature(
patch_size=self.encoder.patch_size,
hooks=[0, 1, 2, 3],
input_feature_dims=[self.encoder.enc_embed_dim] + [self.info_sharing.dim] * 3,
feature_dim=pred_head_feature_dim,
pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[0],
)
self.dpt_regressor_head1 = DPTRegressionProcessor(
input_feature_dim=pred_head_feature_dim,
output_dim=pred_head_output_dim,
pretrained_checkpoint_path=pretrained_pred_head_regressor_checkpoint_paths[0],
)
self.head1 = nn.Sequential(self.dpt_feature_head1, self.dpt_regressor_head1)
# Initialize Prediction Head 2
self.dpt_feature_head2 = DPTFeature(
patch_size=self.encoder.patch_size,
hooks=[0, 1, 2, 3],
input_feature_dims=[self.encoder.enc_embed_dim] + [self.info_sharing.dim] * 3,
feature_dim=pred_head_feature_dim,
pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[1],
)
self.dpt_regressor_head2 = DPTRegressionProcessor(
input_feature_dim=pred_head_feature_dim,
output_dim=pred_head_output_dim,
pretrained_checkpoint_path=pretrained_pred_head_regressor_checkpoint_paths[1],
)
self.head2 = nn.Sequential(self.dpt_feature_head2, self.dpt_regressor_head2)
# Initialize Final Output Adaptor
self.adaptor = PointMapWithConfidenceAdaptor(
name="pointmap",
pointmap_mode=depth_mode[0],
pointmap_vmin=depth_mode[1],
pointmap_vmax=depth_mode[2],
confidence_type=conf_mode[0],
confidence_vmin=conf_mode[1],
confidence_vmax=conf_mode[2],
)
# Load pretrained weights
if self.pretrained_checkpoint_path is not None:
print(f"Loading pretrained DUSt3R weights from {self.pretrained_checkpoint_path} ...")
ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
print(self.load_state_dict(ckpt["model"]))
def _encode_image_pairs(self, img1, img2, data_norm_type):
"Encode two different batches of images (each batch can have different image shape)"
if img1.shape[-2:] == img2.shape[-2:]:
encoder_input = ViTEncoderInput(image=torch.cat((img1, img2), dim=0), data_norm_type=data_norm_type)
encoder_output = self.encoder(encoder_input)
out, out2 = encoder_output.features.chunk(2, dim=0)
else:
encoder_input = ViTEncoderInput(image=img1, data_norm_type=data_norm_type)
out = self.encoder(encoder_input)
out = out.features
            encoder_input2 = ViTEncoderInput(image=img2, data_norm_type=data_norm_type)
out2 = self.encoder(encoder_input2)
out2 = out2.features
return out, out2
def _encode_symmetrized(self, view1, view2):
"Encode image pairs accounting for symmetrization, i.e., (a, b) and (b, a) always exist in the input"
img1 = view1["img"]
img2 = view2["img"]
if is_symmetrized(view1, view2):
            # Compute only half of the forward pass for the unique pairs
feat1, feat2 = self._encode_image_pairs(img1[::2], img2[::2], data_norm_type=view1["data_norm_type"])
feat1, feat2 = interleave(feat1, feat2)
else:
feat1, feat2 = self._encode_image_pairs(img1, img2, data_norm_type=view1["data_norm_type"])
return feat1, feat2
def _downstream_head(self, head_num, decout, img_shape):
"Run the respective prediction heads"
head = getattr(self, f"head{head_num}")
if self.pred_head_type == "linear":
head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"])
elif self.pred_head_type == "dpt":
head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape)
return head(head_input)
def forward(self, view1, view2):
"""
Forward pass for DUSt3R performing the following operations:
1. Encodes the two input views (images).
2. Combines the encoded features using a two-view cross-attention transformer.
3. Passes the combined features through the respective prediction heads.
4. Returns the processed final outputs for both views.
Args:
view1 (dict): Dictionary containing the first view's images and instance information.
"img" is a required key and value is a tensor of shape (B, C, H, W).
view2 (dict): Dictionary containing the second view's images and instance information.
"img" is a required key and value is a tensor of shape (B, C, H, W).
Returns:
Tuple[dict, dict]: A tuple containing the final outputs for both views.
"""
# Get input shapes
_, _, height1, width1 = view1["img"].shape
_, _, height2, width2 = view2["img"].shape
shape1 = (int(height1), int(width1))
shape2 = (int(height2), int(width2))
# Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width)
feat1, feat2 = self._encode_symmetrized(view1, view2)
# Combine all images into view-centric representation
info_sharing_input = MultiViewTransformerInput(features=[feat1, feat2])
if self.pred_head_type == "linear":
final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input)
elif self.pred_head_type == "dpt":
final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing(
info_sharing_input
)
if self.pred_head_type == "linear":
# Define feature dictionary for linear head
info_sharing_outputs = {
"1": final_info_sharing_multi_view_feat.features[0].float(),
"2": final_info_sharing_multi_view_feat.features[1].float(),
}
elif self.pred_head_type == "dpt":
# Define feature dictionary for DPT head
info_sharing_outputs = {
"1": [
feat1.float(),
intermediate_info_sharing_multi_view_feat[0].features[0].float(),
intermediate_info_sharing_multi_view_feat[1].features[0].float(),
final_info_sharing_multi_view_feat.features[0].float(),
],
"2": [
feat2.float(),
intermediate_info_sharing_multi_view_feat[0].features[1].float(),
intermediate_info_sharing_multi_view_feat[1].features[1].float(),
final_info_sharing_multi_view_feat.features[1].float(),
],
}
# Downstream task prediction
with torch.autocast("cuda", enabled=False):
# Prediction heads
head_output1 = self._downstream_head(1, info_sharing_outputs, shape1)
head_output2 = self._downstream_head(2, info_sharing_outputs, shape2)
# Post-process outputs
final_output1 = self.adaptor(
AdaptorInput(adaptor_feature=head_output1.decoded_channels, output_shape_hw=shape1)
)
final_output2 = self.adaptor(
AdaptorInput(adaptor_feature=head_output2.decoded_channels, output_shape_hw=shape2)
)
# Convert outputs to dictionary
res1 = {
"pts3d": final_output1.value.permute(0, 2, 3, 1).contiguous(),
"conf": final_output1.confidence.permute(0, 2, 3, 1).contiguous(),
}
res2 = {
"pts3d_in_other_view": final_output2.value.permute(0, 2, 3, 1).contiguous(),
"conf": final_output2.confidence.permute(0, 2, 3, 1).contiguous(),
}
return res1, res2
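

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module). It builds a
    # randomly initialized DUSt3R with the default linear head and runs a dummy forward
    # pass; the tensor shapes and the "dust3r" normalization tag follow the defaults
    # documented in __init__, while the image content and instance ids are made up.
    model = DUSt3R(name="dust3r_example", pred_head_type="linear", img_size=(224, 224))
    model.eval()

    img_a = torch.randn(1, 3, 224, 224)
    img_b = torch.randn(1, 3, 224, 224)
    # Symmetrized pair layout: both (a, b) and (b, a) appear in the batch, so the
    # encoder only processes each unique image once (see _encode_symmetrized).
    view1 = {"img": torch.cat([img_a, img_b], dim=0), "instance": ["a", "b"], "data_norm_type": "dust3r"}
    view2 = {"img": torch.cat([img_b, img_a], dim=0), "instance": ["b", "a"], "data_norm_type": "dust3r"}

    with torch.no_grad():
        res1, res2 = model(view1, view2)

    # res1["pts3d"] / res2["pts3d_in_other_view"] are expected to be (B, H, W, 3)
    # pointmaps in view1's frame, and the "conf" entries (B, H, W, 1) confidence maps.
    print(res1["pts3d"].shape, res1["conf"].shape)
    print(res2["pts3d_in_other_view"].shape, res2["conf"].shape)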