from typing import Tuple, List, Optional, Dict

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

from longstream.utils.vendor.dust3r.utils.misc import freeze_all_params
from longstream.utils.vendor.models.components.aggregator.streamaggregator import (
    STreamAggregator,
)
from longstream.utils.vendor.models.components.heads.camera_head import (
    CameraHead,
    RelPoseHead,
)
from longstream.utils.vendor.models.components.heads.dpt_head import DPTHead


class LongStream(nn.Module, PyTorchModelHubMixin):
    """Streaming multi-view 3D reconstruction model.

    Combines a streaming token aggregator with several prediction heads:

    - ``camera_head``:   absolute camera pose encodings (optional).
    - ``rel_pose_head``: keyframe-relative pose encodings (optional, config-driven).
    - ``point_head``:    per-pixel world points + confidence.
    - ``depth_head``:    per-pixel depth + confidence.
    - scale token/head:  a learned global metric-scale factor (optional) that
      rescales translations, world points, and depths.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        freeze="none",
        rel_pose_head_cfg=None,
        use_role_embedding=True,
        enable_scale_token=False,
        scale_token_config=None,
        disable_keyframe_distinction=False,
        enable_camera_head=True,
        use_segment_mask=False,
        use_3d_rope=False,
        rope_freq=100,
        window_size=5000,
    ):
        """Build the aggregator and all enabled heads.

        Args:
            img_size: Input image side length fed to the patch embedder.
            patch_size: ViT patch size.
            embed_dim: Aggregator embedding dimension; heads consume ``2 * embed_dim``.
            freeze: Which sub-modules to freeze; one of the keys accepted by
                :meth:`set_freeze` (``"none"`` or ``"encoder"``).
            rel_pose_head_cfg: Optional dict configuring :class:`RelPoseHead`;
                ``None`` (or ``enabled: False``) disables the relative-pose head.
            use_role_embedding: Forwarded to the aggregator.
            enable_scale_token: If True, learn a global scale token + MLP head.
            scale_token_config: Extra config for the scale components (currently
                only forwarded; may be empty).
            disable_keyframe_distinction: Forwarded to the aggregator.
            enable_camera_head: If False, ``self.camera_head`` is ``None``.
            use_segment_mask: Forwarded to the aggregator.
            use_3d_rope: Forwarded to the aggregator.
            rope_freq: RoPE base frequency, forwarded to the aggregator.
            window_size: Streaming attention window, shared by aggregator and heads.
        """
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.enable_scale_token = enable_scale_token
        self.enable_camera_head = enable_camera_head
        self.window_size = window_size

        self.aggregator = STreamAggregator(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            use_role_embedding=use_role_embedding,
            disable_keyframe_distinction=disable_keyframe_distinction,
            use_segment_mask=use_segment_mask,
            use_3d_rope=use_3d_rope,
            rope_freq=rope_freq,
            window_size=window_size,
        )

        if self.enable_camera_head:
            self.camera_head = CameraHead(dim_in=2 * embed_dim, window_size=window_size)
        else:
            self.camera_head = None

        # Dense heads always exist: 3D points (x, y, z, conf) and depth (d, conf).
        self.point_head = DPTHead(
            dim_in=2 * embed_dim,
            output_dim=4,
            activation="inv_log",
            conf_activation="expp1",
        )
        self.depth_head = DPTHead(
            dim_in=2 * embed_dim,
            output_dim=2,
            activation="exp",
            conf_activation="expp1",
        )

        self.rel_pose_head = None
        self.reinit_camera_head_when_rel_enabled = False
        if rel_pose_head_cfg is not None and rel_pose_head_cfg.get("enabled", True):
            head_cfg = {
                "dim_in": 2 * embed_dim,
                "trunk_depth": rel_pose_head_cfg.get("trunk_depth", 4),
                "pose_mode": rel_pose_head_cfg.get("pose_mode", "SE3"),
                "num_heads": rel_pose_head_cfg.get("num_heads", 16),
                "mlp_ratio": rel_pose_head_cfg.get("mlp_ratio", 4),
                "init_values": rel_pose_head_cfg.get("init_values", 0.01),
                "trans_act": rel_pose_head_cfg.get("trans_act", "linear"),
                "quat_act": rel_pose_head_cfg.get("quat_act", "linear"),
                "fl_act": rel_pose_head_cfg.get("fl_act", "relu"),
                "use_global_scale": rel_pose_head_cfg.get("use_global_scale", False),
                "use_pair_cross_attn": rel_pose_head_cfg.get(
                    "use_pair_cross_attn", False
                ),
                "detach_reference": rel_pose_head_cfg.get("detach_reference", False),
                "xattn_temperature": rel_pose_head_cfg.get("xattn_temperature", 1.0),
                "use_precat": rel_pose_head_cfg.get("use_precat", False),
                "use_kf_role_embed": rel_pose_head_cfg.get("use_kf_role_embed", True),
                "kf_role_embed_init_std": rel_pose_head_cfg.get(
                    "kf_role_embed_init_std", 0.02
                ),
                "window_size": window_size,
            }
            self.rel_pose_head = RelPoseHead(**head_cfg)
            # NOTE: even when this flag is set, nothing is reinitialized here.
            # reinitialize_camera_head() must be called by the owner AFTER
            # checkpoint loading (see that method's docstring).
            self.reinit_camera_head_when_rel_enabled = rel_pose_head_cfg.get(
                "reinit_camera_head", False
            )

        if self.enable_scale_token:
            self._init_scale_components(scale_token_config or {})

        self.set_freeze(freeze)

    def reinitialize_camera_head(self):
        """
        Reinitialize camera_head with fresh weights.

        This is useful when:
        1. Loading a pretrained checkpoint that has camera_head weights
        2. But we want to train camera_head from scratch with new settings
           (e.g., quaternion normalization)

        This method should be called AFTER checkpoint loading.
        """
        old_camera_head = self.camera_head
        # Recover dim_in from the old head's input LayerNorm rather than
        # trusting self.embed_dim, in case the checkpoint differs.
        dim_in = old_camera_head.token_norm.normalized_shape[0]
        # BUGFIX: previously CameraHead was rebuilt without window_size,
        # silently reverting a custom streaming window to the class default.
        # Keep it consistent with the construction in __init__.
        self.camera_head = CameraHead(dim_in=dim_in, window_size=self.window_size)
        device = next(old_camera_head.parameters()).device
        self.camera_head = self.camera_head.to(device)

    def _init_scale_components(self, config):
        """Create the learnable scale token and the scalar scale-regression MLP.

        Args:
            config: Scale-token config dict (currently unused beyond presence;
                kept for forward compatibility).
        """
        self.scale_token = nn.Parameter(torch.zeros(self.embed_dim))
        torch.nn.init.trunc_normal_(self.scale_token, std=0.02)
        self.scale_head = nn.Sequential(
            nn.Linear(2 * self.embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )
        for m in self.scale_head.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=1.0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
        import math

        # Bias the final layer so the initial predicted scale is exp(log 30) = 30,
        # a reasonable metric-scale prior before any training.
        nn.init.constant_(self.scale_head[-1].bias, math.log(30.0))

    def set_freeze(self, freeze):
        """Freeze parameter groups by name.

        Args:
            freeze: ``"none"`` (freeze nothing) or ``"encoder"`` (freeze the
                aggregator's patch embedder). Raises ``KeyError`` for other values.
        """
        self.freeze = freeze
        to_be_frozen = {
            "none": [],
            "encoder": [self.aggregator.patch_embed],
        }
        freeze_all_params(to_be_frozen[freeze])

    def forward(
        self,
        images: torch.Tensor,
        mode: str = "causal",
        aggregator_kv_cache_list: Optional[List[List[torch.Tensor]]] = None,
        camera_head_kv_cache_list: Optional[List[List[List[torch.Tensor]]]] = None,
        rel_pose_inputs: Optional[Dict] = None,
        is_keyframe: Optional[torch.Tensor] = None,
    ):
        """Run aggregation and all enabled heads.

        Args:
            images: Input frames; 4D input is treated as a single sequence and
                promoted to 5D by unsqueezing a batch dim.
                (assumes (B, S, C, H, W) after promotion — TODO confirm against aggregator)
            mode: Attention mode passed to aggregator / camera head (e.g. "causal").
            aggregator_kv_cache_list: Optional streaming KV cache for the
                aggregator; when given, the cached-inference code path is used
                and the updated cache is returned in ``predictions``.
            camera_head_kv_cache_list: Same, for the camera head.
            rel_pose_inputs: Optional dict driving the relative-pose head
                (keys used here: ``keyframe_indices``, ``is_keyframe``,
                ``num_iterations``, ``kv_cache_list``).
            is_keyframe: Optional per-frame keyframe mask forwarded to the
                aggregator (and as the rel-pose default).

        Returns:
            Dict of predictions. Possible keys: ``pose_enc``, ``pose_enc_list``
            (training only), ``rel_pose_enc``, ``rel_pose_enc_list``,
            ``world_points``, ``world_points_conf``, ``depth``, ``depth_conf``,
            ``predicted_scale_factor``, ``scale_token_features``,
            ``is_keyframe``, ``keyframe_indices``, ``global_scale``,
            ``rel_pose_kv_cache_list``, ``aggregator_kv_cache_list``,
            ``camera_head_kv_cache_list``, and (eval only) ``images``.
        """
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        batch_size = images.shape[0]

        # Optional learned scale token appended to the aggregator's input.
        additional_tokens = None
        if self.enable_scale_token:
            scale_token_base = self.scale_token.unsqueeze(0).repeat(batch_size, 1)
            additional_tokens = scale_token_base.unsqueeze(-1)

        keyframe_indices = None
        if rel_pose_inputs is not None and "keyframe_indices" in rel_pose_inputs:
            keyframe_indices = rel_pose_inputs["keyframe_indices"]

        # The aggregator returns different tuples on the cached vs. uncached path.
        if aggregator_kv_cache_list is not None:
            (
                aggregated_tokens_list,
                patch_start_idx,
                aggregator_kv_cache_list,
                _,
            ) = self.aggregator(
                images,
                mode=mode,
                kv_cache_list=aggregator_kv_cache_list,
                is_keyframe=is_keyframe,
                keyframe_indices=keyframe_indices,
                additional_tokens=additional_tokens,
                reorder_keyframes_first=False,
            )
        else:
            aggregated_tokens_list, patch_start_idx, _ = self.aggregator(
                images,
                mode=mode,
                is_keyframe=is_keyframe,
                keyframe_indices=keyframe_indices,
                additional_tokens=additional_tokens,
                reorder_keyframes_first=False,
            )

        predictions = {}

        # --- Global scale prediction (optional) -------------------------------
        predicted_scale_factor = None
        if self.enable_scale_token and additional_tokens is not None:
            if len(aggregated_tokens_list) > 0:
                last_layer_features = aggregated_tokens_list[-1]
                # The scale token sits immediately before the first patch token.
                scale_token_idx = patch_start_idx - 1
                scale_token_output_features = last_layer_features[
                    :, :, scale_token_idx, :
                ]
                # Average over the frame dimension, then regress a log-scale.
                scale_token_output_features = scale_token_output_features.mean(dim=1)
                scale_logits = self.scale_head(scale_token_output_features).squeeze(-1)
                predicted_scale_factor = torch.exp(scale_logits)
                predictions["predicted_scale_factor"] = predicted_scale_factor
                predictions["scale_token_features"] = scale_token_output_features

        # --- Absolute camera poses (optional) ---------------------------------
        if self.enable_camera_head and self.camera_head is not None:
            if camera_head_kv_cache_list is not None:
                pose_enc_list, camera_head_kv_cache_list = self.camera_head(
                    aggregated_tokens_list,
                    mode=mode,
                    kv_cache_list=camera_head_kv_cache_list,
                )
            else:
                pose_enc_list = self.camera_head(aggregated_tokens_list, mode=mode)
            final_pose_enc = pose_enc_list[-1]
            # Scale only the translation part (first 3 channels) of pose encodings.
            if self.enable_scale_token and predicted_scale_factor is not None:
                scale = predicted_scale_factor.view(-1, 1, 1)
                scaled_t = final_pose_enc[..., :3] * scale
                scaled_pose_enc = torch.cat(
                    [scaled_t, final_pose_enc[..., 3:]], dim=-1
                )
                predictions["pose_enc"] = scaled_pose_enc
            else:
                predictions["pose_enc"] = final_pose_enc
            if self.training:
                # During training, also expose every iteration for deep supervision.
                if self.enable_scale_token and predicted_scale_factor is not None:
                    scale = predicted_scale_factor.view(-1, 1, 1)
                    scaled_pose_enc_list = []
                    for pose_enc in pose_enc_list:
                        scaled_t = pose_enc[..., :3] * scale
                        scaled_pose_enc = torch.cat(
                            [scaled_t, pose_enc[..., 3:]], dim=-1
                        )
                        scaled_pose_enc_list.append(scaled_pose_enc)
                    predictions["pose_enc_list"] = scaled_pose_enc_list
                else:
                    predictions["pose_enc_list"] = pose_enc_list

        # --- Keyframe-relative poses (optional) -------------------------------
        if self.rel_pose_head is not None and rel_pose_inputs is not None:
            rel_kwargs = dict(
                aggregated_tokens_list=aggregated_tokens_list,
                keyframe_indices=rel_pose_inputs.get("keyframe_indices"),
                is_keyframe=rel_pose_inputs.get("is_keyframe", is_keyframe),
                num_iterations=rel_pose_inputs.get("num_iterations", 4),
                mode=mode,
                kv_cache_list=rel_pose_inputs.get("kv_cache_list"),
            )
            # Drop Nones so the head's own defaults apply.
            rel_kwargs = {k: v for k, v in rel_kwargs.items() if v is not None}
            rel_result = self.rel_pose_head(**rel_kwargs)
            if isinstance(rel_result, dict):
                pose_enc = rel_result["pose_enc"]
                if pose_enc.dtype != torch.float32:
                    pose_enc = pose_enc.float()
                if self.enable_scale_token and predicted_scale_factor is not None:
                    scale = predicted_scale_factor.view(-1, 1, 1)
                    scaled_t = pose_enc[..., :3] * scale
                    scaled_rel_pose_enc = torch.cat(
                        [scaled_t, pose_enc[..., 3:]], dim=-1
                    )
                    predictions["rel_pose_enc"] = scaled_rel_pose_enc
                    if "pose_enc_list" in rel_result:
                        scaled_pose_enc_list = []
                        for iter_pose in rel_result["pose_enc_list"]:
                            scaled_t = iter_pose[..., :3] * scale
                            scaled_iter_pose = torch.cat(
                                [scaled_t, iter_pose[..., 3:]], dim=-1
                            )
                            scaled_pose_enc_list.append(scaled_iter_pose)
                        predictions["rel_pose_enc_list"] = scaled_pose_enc_list
                else:
                    predictions["rel_pose_enc"] = pose_enc
                    if "pose_enc_list" in rel_result:
                        predictions["rel_pose_enc_list"] = rel_result["pose_enc_list"]
                predictions["is_keyframe"] = rel_result.get("is_keyframe")
                predictions["keyframe_indices"] = rel_result.get("keyframe_indices")
                if "global_scale" in rel_result:
                    predictions["global_scale"] = rel_result["global_scale"]
                if "kv_cache_list" in rel_result:
                    predictions["rel_pose_kv_cache_list"] = rel_result["kv_cache_list"]

        # --- Dense world points -----------------------------------------------
        if self.point_head is not None:
            pts3d, pts3d_conf = self.point_head(
                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
            )
            if self.enable_scale_token and predicted_scale_factor is not None:
                scale = predicted_scale_factor.view(-1, 1, 1, 1, 1)
                predictions["world_points"] = pts3d * scale
            else:
                predictions["world_points"] = pts3d
            predictions["world_points_conf"] = pts3d_conf

        # --- Dense depth ------------------------------------------------------
        if self.depth_head is not None:
            depth, depth_conf = self.depth_head(
                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
            )
            if self.enable_scale_token and predicted_scale_factor is not None:
                scale = predicted_scale_factor.view(-1, 1, 1, 1, 1)
                predictions["depth"] = depth * scale
            else:
                predictions["depth"] = depth
            predictions["depth_conf"] = depth_conf

        # Propagate updated streaming caches back to the caller.
        if aggregator_kv_cache_list is not None:
            predictions["aggregator_kv_cache_list"] = aggregator_kv_cache_list
        if camera_head_kv_cache_list is not None:
            predictions["camera_head_kv_cache_list"] = camera_head_kv_cache_list
        if not self.training:
            predictions["images"] = images
        return predictions