from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from uniception.models.encoders import ViTEncoderInput
from uniception.models.encoders.croco import CroCoEncoder
from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT
from uniception.models.info_sharing.base import MultiViewTransformerInput
from uniception.models.info_sharing.cross_attention_transformer import (
MultiViewCrossAttentionTransformer,
MultiViewCrossAttentionTransformerIFR,
)
from uniception.models.libs.croco.pos_embed import RoPE2D, get_2d_sincos_pos_embed
from uniception.models.prediction_heads.adaptors import PointMapWithConfidenceAdaptor
from uniception.models.prediction_heads.base import AdaptorInput, PredictionHeadInput, PredictionHeadLayeredInput
from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
from uniception.models.prediction_heads.linear import LinearFeature
def is_symmetrized(gt1, gt2):
"Function to check if input pairs are symmetrized, i.e., (a, b) and (b, a) always exist in the input"
x = gt1["instance"]
y = gt2["instance"]
if len(x) == len(y) and len(x) == 1:
return False # special case of batchsize 1
ok = True
for i in range(0, len(x), 2):
ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i])
return ok
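# Example: with view1["instance"] = ["a", "b"] and view2["instance"] = ["b", "a"],
# the batch contains both (a, b) and (b, a), so is_symmetrized returns True.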
def interleave(tensor1, tensor2):
"Interleave two tensors along the first dimension (used to avoid redundant encoding for symmetrized pairs)"
res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
return res1, res2
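# Example: with tensor1 = [a0, a1] and tensor2 = [b0, b1] along dim 0, interleave
# returns ([a0, b0, a1, b1], [b0, a0, b1, a1]), matching the (a, b), (b, a)
# ordering of symmetrized pairs.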
class DUSt3R(nn.Module):
"DUSt3R defined with UniCeption Modules"
def __init__(
self,
name: str,
data_norm_type: str = "dust3r",
img_size: tuple = (224, 224),
patch_embed_cls: str = "PatchEmbedDust3R",
pred_head_type: str = "linear",
pred_head_output_dim: int = 4,
pred_head_feature_dim: int = 256,
depth_mode: Tuple[str, float, float] = ("exp", -float("inf"), float("inf")),
conf_mode: Tuple[str, float, float] = ("exp", 1, float("inf")),
pos_embed: str = "RoPE100",
        pretrained_checkpoint_path: Optional[str] = None,
        pretrained_encoder_checkpoint_path: Optional[str] = None,
        pretrained_info_sharing_checkpoint_path: Optional[str] = None,
        pretrained_pred_head_checkpoint_paths: List[Optional[str]] = [None, None],
        pretrained_pred_head_regressor_checkpoint_paths: List[Optional[str]] = [None, None],
override_encoder_checkpoint_attributes: bool = False,
*args,
**kwargs,
):
"""
        Two-view model containing siamese encoders followed by a two-view cross-attention transformer and respective downstream heads.
        The goal is to output the scene representation directly, with both views expressed in view1's frame (hence the asymmetry).
Args:
name (str): Name of the model.
data_norm_type (str): Type of data normalization. (default: "dust3r")
img_size (tuple): Size of input images. (default: (224, 224))
patch_embed_cls (str): Class for patch embedding. (default: "PatchEmbedDust3R"). Options:
- "PatchEmbedDust3R"
- "ManyAR_PatchEmbed"
pred_head_type (str): Type of prediction head. (default: "linear"). Options:
- "linear"
- "dpt"
pred_head_output_dim (int): Output dimension of prediction head. (default: 4)
pred_head_feature_dim (int): Feature dimension of prediction head. (default: 256)
depth_mode (Tuple[str, float, float]): Depth mode settings (mode=['linear', 'square', 'exp'], vmin, vmax). (default: ('exp', -inf, inf))
conf_mode (Tuple[str, float, float]): Confidence mode settings (mode=['linear', 'square', 'exp'], vmin, vmax). (default: ('exp', 1, inf))
pos_embed (str): Position embedding type. (default: 'RoPE100')
            pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None)
            pretrained_encoder_checkpoint_path (str): Path to pretrained encoder checkpoint. (default: None)
            pretrained_info_sharing_checkpoint_path (str): Path to pretrained info_sharing checkpoint. (default: None)
            pretrained_pred_head_checkpoint_paths (List[str]): Paths to pretrained prediction head checkpoints. (default: [None, None])
            pretrained_pred_head_regressor_checkpoint_paths (List[str]): Paths to pretrained prediction head regressor checkpoints. (default: [None, None])
override_encoder_checkpoint_attributes (bool): Whether to override encoder checkpoint attributes. (default: False)
"""
super().__init__(*args, **kwargs)
        # Initialize the attributes
self.name = name
self.data_norm_type = data_norm_type
self.img_size = img_size
self.patch_embed_cls = patch_embed_cls
self.pred_head_type = pred_head_type
self.pred_head_output_dim = pred_head_output_dim
self.depth_mode = depth_mode
self.conf_mode = conf_mode
self.pos_embed = pos_embed
self.pretrained_checkpoint_path = pretrained_checkpoint_path
self.pretrained_encoder_checkpoint_path = pretrained_encoder_checkpoint_path
self.pretrained_info_sharing_checkpoint_path = pretrained_info_sharing_checkpoint_path
self.pretrained_pred_head_checkpoint_paths = pretrained_pred_head_checkpoint_paths
self.pretrained_pred_head_regressor_checkpoint_paths = pretrained_pred_head_regressor_checkpoint_paths
self.override_encoder_checkpoint_attributes = override_encoder_checkpoint_attributes
# Initialize RoPE for the CroCo Encoder & Two-View Cross Attention Transformer
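        # e.g., pos_embed = "RoPE100" parses to freq = 100.0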
freq = float(pos_embed[len("RoPE") :])
self.rope = RoPE2D(freq=freq)
# Initialize Encoder
self.encoder = CroCoEncoder(
name=name,
data_norm_type=data_norm_type,
patch_embed_cls=patch_embed_cls,
img_size=img_size,
pretrained_checkpoint_path=pretrained_encoder_checkpoint_path,
override_checkpoint_attributes=override_encoder_checkpoint_attributes,
)
# Initialize Multi-View Cross Attention Transformer
if self.pred_head_type == "linear":
# Returns only normalized last layer features
self.info_sharing = MultiViewCrossAttentionTransformer(
name="base_info_sharing",
input_embed_dim=self.encoder.enc_embed_dim,
num_views=2,
custom_positional_encoding=self.rope,
pretrained_checkpoint_path=pretrained_info_sharing_checkpoint_path,
)
elif self.pred_head_type == "dpt":
# Returns intermediate features and normalized last layer features
self.info_sharing = MultiViewCrossAttentionTransformerIFR(
name="base_info_sharing",
input_embed_dim=self.encoder.enc_embed_dim,
num_views=2,
indices=[5, 8],
norm_intermediate=False,
custom_positional_encoding=self.rope,
pretrained_checkpoint_path=pretrained_info_sharing_checkpoint_path,
)
else:
raise ValueError(f"Invalid prediction head type: {pred_head_type}. Must be 'linear' or 'dpt'.")
# Initialize Prediction Heads
if pred_head_type == "linear":
# Initialize Prediction Head 1
self.head1 = LinearFeature(
input_feature_dim=self.info_sharing.dim,
output_dim=pred_head_output_dim,
patch_size=self.encoder.patch_size,
pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[0],
)
# Initialize Prediction Head 2
self.head2 = LinearFeature(
input_feature_dim=self.info_sharing.dim,
output_dim=pred_head_output_dim,
patch_size=self.encoder.patch_size,
pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[1],
)
elif pred_head_type == "dpt":
            # Initialize Prediction Head 1
self.dpt_feature_head1 = DPTFeature(
patch_size=self.encoder.patch_size,
hooks=[0, 1, 2, 3],
input_feature_dims=[self.encoder.enc_embed_dim] + [self.info_sharing.dim] * 3,
feature_dim=pred_head_feature_dim,
pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[0],
)
self.dpt_regressor_head1 = DPTRegressionProcessor(
input_feature_dim=pred_head_feature_dim,
output_dim=pred_head_output_dim,
pretrained_checkpoint_path=pretrained_pred_head_regressor_checkpoint_paths[0],
)
self.head1 = nn.Sequential(self.dpt_feature_head1, self.dpt_regressor_head1)
# Initialize Prediction Head 2
self.dpt_feature_head2 = DPTFeature(
patch_size=self.encoder.patch_size,
hooks=[0, 1, 2, 3],
input_feature_dims=[self.encoder.enc_embed_dim] + [self.info_sharing.dim] * 3,
feature_dim=pred_head_feature_dim,
pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[1],
)
self.dpt_regressor_head2 = DPTRegressionProcessor(
input_feature_dim=pred_head_feature_dim,
output_dim=pred_head_output_dim,
pretrained_checkpoint_path=pretrained_pred_head_regressor_checkpoint_paths[1],
)
self.head2 = nn.Sequential(self.dpt_feature_head2, self.dpt_regressor_head2)
# Initialize Final Output Adaptor
self.adaptor = PointMapWithConfidenceAdaptor(
name="pointmap",
pointmap_mode=depth_mode[0],
pointmap_vmin=depth_mode[1],
pointmap_vmax=depth_mode[2],
confidence_type=conf_mode[0],
confidence_vmin=conf_mode[1],
confidence_vmax=conf_mode[2],
)
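        # Note: with the defaults ("exp", -inf, inf) and ("exp", 1, inf), the adaptor
        # applies an exponential mapping to the raw pointmap and confidence channels
        # and clamps them to [vmin, vmax]; see the PointMapWithConfidenceAdaptor
        # implementation for the exact formula.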
# Load pretrained weights
if self.pretrained_checkpoint_path is not None:
print(f"Loading pretrained DUSt3R weights from {self.pretrained_checkpoint_path} ...")
ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
print(self.load_state_dict(ckpt["model"]))
def _encode_image_pairs(self, img1, img2, data_norm_type):
"Encode two different batches of images (each batch can have different image shape)"
if img1.shape[-2:] == img2.shape[-2:]:
encoder_input = ViTEncoderInput(image=torch.cat((img1, img2), dim=0), data_norm_type=data_norm_type)
encoder_output = self.encoder(encoder_input)
out, out2 = encoder_output.features.chunk(2, dim=0)
else:
encoder_input = ViTEncoderInput(image=img1, data_norm_type=data_norm_type)
out = self.encoder(encoder_input)
out = out.features
            encoder_input2 = ViTEncoderInput(image=img2, data_norm_type=data_norm_type)
out2 = self.encoder(encoder_input2)
out2 = out2.features
return out, out2
def _encode_symmetrized(self, view1, view2):
"Encode image pairs accounting for symmetrization, i.e., (a, b) and (b, a) always exist in the input"
img1 = view1["img"]
img2 = view2["img"]
if is_symmetrized(view1, view2):
            # Compute only half of the forward pass and interleave the results
feat1, feat2 = self._encode_image_pairs(img1[::2], img2[::2], data_norm_type=view1["data_norm_type"])
feat1, feat2 = interleave(feat1, feat2)
else:
feat1, feat2 = self._encode_image_pairs(img1, img2, data_norm_type=view1["data_norm_type"])
return feat1, feat2
def _downstream_head(self, head_num, decout, img_shape):
"Run the respective prediction heads"
head = getattr(self, f"head{head_num}")
if self.pred_head_type == "linear":
head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"])
elif self.pred_head_type == "dpt":
head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape)
return head(head_input)
def forward(self, view1, view2):
"""
Forward pass for DUSt3R performing the following operations:
1. Encodes the two input views (images).
2. Combines the encoded features using a two-view cross-attention transformer.
3. Passes the combined features through the respective prediction heads.
4. Returns the processed final outputs for both views.
Args:
view1 (dict): Dictionary containing the first view's images and instance information.
"img" is a required key and value is a tensor of shape (B, C, H, W).
view2 (dict): Dictionary containing the second view's images and instance information.
"img" is a required key and value is a tensor of shape (B, C, H, W).
Returns:
Tuple[dict, dict]: A tuple containing the final outputs for both views.
"""
# Get input shapes
_, _, height1, width1 = view1["img"].shape
_, _, height2, width2 = view2["img"].shape
shape1 = (int(height1), int(width1))
shape2 = (int(height2), int(width2))
# Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width)
feat1, feat2 = self._encode_symmetrized(view1, view2)
# Combine all images into view-centric representation
info_sharing_input = MultiViewTransformerInput(features=[feat1, feat2])
if self.pred_head_type == "linear":
final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input)
elif self.pred_head_type == "dpt":
final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing(
info_sharing_input
)
if self.pred_head_type == "linear":
# Define feature dictionary for linear head
info_sharing_outputs = {
"1": final_info_sharing_multi_view_feat.features[0].float(),
"2": final_info_sharing_multi_view_feat.features[1].float(),
}
elif self.pred_head_type == "dpt":
# Define feature dictionary for DPT head
info_sharing_outputs = {
"1": [
feat1.float(),
intermediate_info_sharing_multi_view_feat[0].features[0].float(),
intermediate_info_sharing_multi_view_feat[1].features[0].float(),
final_info_sharing_multi_view_feat.features[0].float(),
],
"2": [
feat2.float(),
intermediate_info_sharing_multi_view_feat[0].features[1].float(),
intermediate_info_sharing_multi_view_feat[1].features[1].float(),
final_info_sharing_multi_view_feat.features[1].float(),
],
}
# Downstream task prediction
with torch.autocast("cuda", enabled=False):
# Prediction heads
head_output1 = self._downstream_head(1, info_sharing_outputs, shape1)
head_output2 = self._downstream_head(2, info_sharing_outputs, shape2)
# Post-process outputs
final_output1 = self.adaptor(
AdaptorInput(adaptor_feature=head_output1.decoded_channels, output_shape_hw=shape1)
)
final_output2 = self.adaptor(
AdaptorInput(adaptor_feature=head_output2.decoded_channels, output_shape_hw=shape2)
)
# Convert outputs to dictionary
res1 = {
"pts3d": final_output1.value.permute(0, 2, 3, 1).contiguous(),
"conf": final_output1.confidence.permute(0, 2, 3, 1).contiguous(),
}
res2 = {
"pts3d_in_other_view": final_output2.value.permute(0, 2, 3, 1).contiguous(),
"conf": final_output2.confidence.permute(0, 2, 3, 1).contiguous(),
}
return res1, res2
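# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes random weights and the default "dust3r" normalization; the model
# name and instance ids below are hypothetical, with the ids chosen so the
# batch is symmetrized. Expected output shapes follow from the adaptor's
# (B, H, W, C) permutation above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DUSt3R(name="dust3r_demo", img_size=(224, 224), pred_head_type="linear")
    model.eval()
    img1 = torch.randn(2, 3, 224, 224)
    img2 = torch.randn(2, 3, 224, 224)
    # (a, b) and (b, a) pairs -> is_symmetrized(view1, view2) is True
    view1 = {"img": img1, "instance": ["a", "b"], "data_norm_type": "dust3r"}
    view2 = {"img": img2, "instance": ["b", "a"], "data_norm_type": "dust3r"}
    with torch.no_grad():
        res1, res2 = model(view1, view2)
    print(res1["pts3d"].shape)  # expected: torch.Size([2, 224, 224, 3])
    print(res2["conf"].shape)   # expected: torch.Size([2, 224, 224, 1])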