# LongStream / longstream /models /longstream.py
# Cc
# init
# e340a84
from typing import Tuple, List, Optional, Dict
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from longstream.utils.vendor.dust3r.utils.misc import freeze_all_params
from longstream.utils.vendor.models.components.aggregator.streamaggregator import (
STreamAggregator,
)
from longstream.utils.vendor.models.components.heads.camera_head import (
CameraHead,
RelPoseHead,
)
from longstream.utils.vendor.models.components.heads.dpt_head import DPTHead
class LongStream(nn.Module, PyTorchModelHubMixin):
    """Multi-head model built on top of a ``STreamAggregator`` backbone.

    The aggregator turns the input images into a list of token tensors that
    are consumed by:

    * an optional ``CameraHead`` (output key ``pose_enc``),
    * an optional ``RelPoseHead`` (output key ``rel_pose_enc``),
    * two ``DPTHead`` instances (output keys ``world_points`` / ``depth``
      plus their confidences).

    When ``enable_scale_token`` is set, a learned token is handed to the
    aggregator and a small MLP maps the token's output features to a
    positive scale factor.  That factor multiplies the first three channels
    of every pose encoding as well as the predicted points and depth.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        freeze="none",
        rel_pose_head_cfg=None,
        use_role_embedding=True,
        enable_scale_token=False,
        scale_token_config=None,
        disable_keyframe_distinction=False,
        enable_camera_head=True,
        use_segment_mask=False,
        use_3d_rope=False,
        rope_freq=100,
        window_size=5000,
    ):
        """Build the aggregator and all prediction heads.

        Args:
            img_size: Forwarded to the aggregator.
            patch_size: Forwarded to the aggregator.
            embed_dim: Forwarded to the aggregator; every head is built with
                ``dim_in = 2 * embed_dim``.
            freeze: Freeze mode passed to :meth:`set_freeze`
                (``"none"`` or ``"encoder"``).
            rel_pose_head_cfg: Optional mapping configuring the relative
                pose head.  The head is built unless the mapping is ``None``
                or its ``"enabled"`` entry (default ``True``) is falsy.
                Its ``"reinit_camera_head"`` entry (default ``False``) is
                stored in ``reinit_camera_head_when_rel_enabled``.
            use_role_embedding: Forwarded to the aggregator.
            enable_scale_token: Create the scale token and scale MLP.
            scale_token_config: Passed to ``_init_scale_components`` (an
                empty dict is substituted for ``None``).
            disable_keyframe_distinction: Forwarded to the aggregator.
            enable_camera_head: Build the ``CameraHead``; otherwise
                ``self.camera_head`` is ``None``.
            use_segment_mask: Forwarded to the aggregator.
            use_3d_rope: Forwarded to the aggregator.
            rope_freq: Forwarded to the aggregator.
            window_size: Forwarded to the aggregator, the camera head and
                the relative pose head.
        """
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.enable_scale_token = enable_scale_token
        self.enable_camera_head = enable_camera_head
        self.window_size = window_size

        # NOTE: sub-modules are created in a fixed order (aggregator, camera
        # head, point head, depth head, rel-pose head, scale components) so
        # that parameter registration and random initialisation stay stable.
        self.aggregator = STreamAggregator(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            use_role_embedding=use_role_embedding,
            disable_keyframe_distinction=disable_keyframe_distinction,
            use_segment_mask=use_segment_mask,
            use_3d_rope=use_3d_rope,
            rope_freq=rope_freq,
            window_size=window_size,
        )

        if self.enable_camera_head:
            self.camera_head = CameraHead(
                dim_in=2 * embed_dim, window_size=window_size
            )
        else:
            self.camera_head = None

        self.point_head = DPTHead(
            dim_in=2 * embed_dim,
            output_dim=4,
            activation="inv_log",
            conf_activation="expp1",
        )
        self.depth_head = DPTHead(
            dim_in=2 * embed_dim,
            output_dim=2,
            activation="exp",
            conf_activation="expp1",
        )

        self.rel_pose_head = None
        self.reinit_camera_head_when_rel_enabled = False
        if rel_pose_head_cfg is not None and rel_pose_head_cfg.get("enabled", True):
            self.rel_pose_head = self._build_rel_pose_head(rel_pose_head_cfg)
            # Only a flag is recorded here; the actual re-initialisation is
            # left to the caller (see ``reinitialize_camera_head``).
            self.reinit_camera_head_when_rel_enabled = rel_pose_head_cfg.get(
                "reinit_camera_head", False
            )

        if self.enable_scale_token:
            self._init_scale_components(scale_token_config or {})

        self.set_freeze(freeze)

    def _build_rel_pose_head(self, cfg):
        """Create the ``RelPoseHead`` from ``cfg``, filling in default values."""
        defaults = {
            "trunk_depth": 4,
            "pose_mode": "SE3",
            "num_heads": 16,
            "mlp_ratio": 4,
            "init_values": 0.01,
            "trans_act": "linear",
            "quat_act": "linear",
            "fl_act": "relu",
            "use_global_scale": False,
            "use_pair_cross_attn": False,
            "detach_reference": False,
            "xattn_temperature": 1.0,
            "use_precat": False,
            "use_kf_role_embed": True,
            "kf_role_embed_init_std": 0.02,
        }
        head_cfg = {"dim_in": 2 * self.embed_dim}
        for key, default in defaults.items():
            head_cfg[key] = cfg.get(key, default)
        head_cfg["window_size"] = self.window_size
        return RelPoseHead(**head_cfg)

    def reinitialize_camera_head(self):
        """
        Reinitialize camera_head with fresh weights.

        This is useful when:
        1. Loading a pretrained checkpoint that has camera_head weights
        2. But we want to train camera_head from scratch with new settings
           (e.g., quaternion normalization)

        This method should be called AFTER checkpoint loading.

        The replacement head reuses the input width of the current head and
        the model's ``window_size``, and is moved to the device of the
        current head's parameters.  If the model was built without a camera
        head (``self.camera_head is None``) the call does nothing.
        """
        old_camera_head = self.camera_head
        if old_camera_head is None:
            # Nothing to reinitialise when the camera head is disabled.
            return

        dim_in = old_camera_head.token_norm.normalized_shape[0]
        device = next(old_camera_head.parameters()).device
        # Pass window_size explicitly so the fresh head matches the one
        # constructed in __init__ instead of falling back to CameraHead's
        # own default.
        new_camera_head = CameraHead(dim_in=dim_in, window_size=self.window_size)
        self.camera_head = new_camera_head.to(device)

    def _init_scale_components(self, config):
        """Create the learnable scale token and the MLP that decodes it.

        Args:
            config: Accepted for interface compatibility; not read by the
                current implementation.
        """
        import math

        self.scale_token = nn.Parameter(torch.zeros(self.embed_dim))
        torch.nn.init.trunc_normal_(self.scale_token, std=0.02)

        self.scale_head = nn.Sequential(
            nn.Linear(2 * self.embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )
        for module in self.scale_head.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=1.0)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)

        # forward() applies exp() to the MLP output, so a final bias of
        # log(30) makes the bias contribution to the scale equal to 30.
        nn.init.constant_(self.scale_head[-1].bias, math.log(30.0))

    def set_freeze(self, freeze):
        """Freeze the parameter groups selected by ``freeze``.

        Args:
            freeze: ``"none"`` (freeze nothing) or ``"encoder"`` (freeze
                ``self.aggregator.patch_embed``).

        Raises:
            KeyError: If ``freeze`` is not one of the supported modes.
        """
        self.freeze = freeze
        to_be_frozen = {
            "none": [],
            "encoder": [self.aggregator.patch_embed],
        }
        if freeze not in to_be_frozen:
            raise KeyError(
                f"Unknown freeze mode {freeze!r}; "
                f"expected one of {sorted(to_be_frozen)}"
            )
        freeze_all_params(to_be_frozen[freeze])

    @staticmethod
    def _scale_translation(pose_enc, scale_factor):
        """Multiply the first three channels of ``pose_enc`` by the scale.

        Returns ``pose_enc`` unchanged when ``scale_factor`` is ``None``.
        The first three channels are presumably the translation part of the
        pose encoding - TODO confirm against the head implementations.
        """
        if scale_factor is None:
            return pose_enc
        scale = scale_factor.view(-1, 1, 1)
        return torch.cat([pose_enc[..., :3] * scale, pose_enc[..., 3:]], dim=-1)

    @classmethod
    def _scale_translation_list(cls, pose_enc_list, scale_factor):
        """Apply :meth:`_scale_translation` to every entry of a list.

        Returns the very same list object when ``scale_factor`` is ``None``.
        """
        if scale_factor is None:
            return pose_enc_list
        return [cls._scale_translation(item, scale_factor) for item in pose_enc_list]

    @staticmethod
    def _scale_dense(values, scale_factor):
        """Multiply a dense head output by the per-sample scale factor.

        The factor is reshaped to ``(-1, 1, 1, 1, 1)`` for broadcasting, so
        ``values`` is presumably 5-D - TODO confirm against ``DPTHead``.
        Returns ``values`` unchanged when ``scale_factor`` is ``None``.
        """
        if scale_factor is None:
            return values
        return values * scale_factor.view(-1, 1, 1, 1, 1)

    def _predict_scale(self, aggregated_tokens_list, patch_start_idx):
        """Decode the scale factor from the last aggregator layer.

        Reads the token located right before ``patch_start_idx``, averages it
        over dim 1, and maps it through ``scale_head`` followed by ``exp``.

        Returns:
            ``(scale_factor, scale_token_features)``.
        """
        last_layer_features = aggregated_tokens_list[-1]
        scale_token_idx = patch_start_idx - 1
        scale_token_features = last_layer_features[:, :, scale_token_idx, :].mean(
            dim=1
        )
        scale_logits = self.scale_head(scale_token_features).squeeze(-1)
        return torch.exp(scale_logits), scale_token_features

    def forward(
        self,
        images: torch.Tensor,
        mode: str = "causal",
        aggregator_kv_cache_list: Optional[List[List[torch.Tensor]]] = None,
        camera_head_kv_cache_list: Optional[List[List[List[torch.Tensor]]]] = None,
        rel_pose_inputs: Optional[Dict] = None,
        is_keyframe: Optional[torch.Tensor] = None,
    ) -> Dict:
        """Run the aggregator and every enabled head.

        Args:
            images: Input tensor.  A 4-D tensor gets a leading dimension of
                size one prepended before processing.
            mode: Forwarded to the aggregator and to the pose heads.
            aggregator_kv_cache_list: Optional aggregator KV cache.  When
                given, the updated cache is returned under
                ``"aggregator_kv_cache_list"``.
            camera_head_kv_cache_list: Optional camera-head KV cache.  When
                given, the updated cache is returned under
                ``"camera_head_kv_cache_list"``.
            rel_pose_inputs: Optional dict for the relative pose head.  Keys
                read here: ``"keyframe_indices"``, ``"is_keyframe"``,
                ``"num_iterations"`` (default 4) and ``"kv_cache_list"``.
                The relative pose head only runs when this is not ``None``.
            is_keyframe: Forwarded to the aggregator, and used as fallback
                for ``rel_pose_inputs["is_keyframe"]``.

        Returns:
            Dict of predictions.  Depending on configuration it contains
            ``predicted_scale_factor``, ``scale_token_features``,
            ``pose_enc``, ``pose_enc_list`` (training only),
            ``rel_pose_enc``, ``rel_pose_enc_list``, ``is_keyframe``,
            ``keyframe_indices``, ``global_scale``,
            ``rel_pose_kv_cache_list``, ``world_points``,
            ``world_points_conf``, ``depth``, ``depth_conf``, the KV caches,
            and ``images`` (only when not training).
        """
        if len(images.shape) == 4:
            images = images.unsqueeze(0)

        additional_tokens = None
        if self.enable_scale_token:
            batch_size = images.shape[0]
            # (embed_dim,) -> (batch, embed_dim) -> (batch, embed_dim, 1)
            additional_tokens = (
                self.scale_token.unsqueeze(0).repeat(batch_size, 1).unsqueeze(-1)
            )

        keyframe_indices = None
        if rel_pose_inputs is not None and "keyframe_indices" in rel_pose_inputs:
            keyframe_indices = rel_pose_inputs["keyframe_indices"]

        aggregator_kwargs = dict(
            mode=mode,
            is_keyframe=is_keyframe,
            keyframe_indices=keyframe_indices,
            additional_tokens=additional_tokens,
            reorder_keyframes_first=False,
        )
        if aggregator_kv_cache_list is not None:
            # With a cache the aggregator returns four values, the third
            # being the updated cache.
            (
                aggregated_tokens_list,
                patch_start_idx,
                aggregator_kv_cache_list,
                _,
            ) = self.aggregator(
                images, kv_cache_list=aggregator_kv_cache_list, **aggregator_kwargs
            )
        else:
            aggregated_tokens_list, patch_start_idx, _ = self.aggregator(
                images, **aggregator_kwargs
            )

        predictions = {}

        # additional_tokens is only set when the scale token is enabled, so
        # predicted_scale_factor stays None whenever scaling is off.
        predicted_scale_factor = None
        if additional_tokens is not None and len(aggregated_tokens_list) > 0:
            predicted_scale_factor, scale_token_features = self._predict_scale(
                aggregated_tokens_list, patch_start_idx
            )
            predictions["predicted_scale_factor"] = predicted_scale_factor
            predictions["scale_token_features"] = scale_token_features

        if self.enable_camera_head and self.camera_head is not None:
            if camera_head_kv_cache_list is not None:
                pose_enc_list, camera_head_kv_cache_list = self.camera_head(
                    aggregated_tokens_list,
                    mode=mode,
                    kv_cache_list=camera_head_kv_cache_list,
                )
            else:
                pose_enc_list = self.camera_head(aggregated_tokens_list, mode=mode)

            predictions["pose_enc"] = self._scale_translation(
                pose_enc_list[-1], predicted_scale_factor
            )
            # Intermediate predictions are only exposed during training.
            if self.training:
                predictions["pose_enc_list"] = self._scale_translation_list(
                    pose_enc_list, predicted_scale_factor
                )

        if self.rel_pose_head is not None and rel_pose_inputs is not None:
            rel_kwargs = dict(
                aggregated_tokens_list=aggregated_tokens_list,
                keyframe_indices=rel_pose_inputs.get("keyframe_indices"),
                is_keyframe=rel_pose_inputs.get("is_keyframe", is_keyframe),
                num_iterations=rel_pose_inputs.get("num_iterations", 4),
                mode=mode,
                kv_cache_list=rel_pose_inputs.get("kv_cache_list"),
            )
            # Drop unset entries so the head falls back to its own defaults.
            rel_kwargs = {
                key: value for key, value in rel_kwargs.items() if value is not None
            }
            rel_result = self.rel_pose_head(**rel_kwargs)

            if isinstance(rel_result, dict):
                rel_pose_enc = rel_result["pose_enc"]
                if rel_pose_enc.dtype != torch.float32:
                    rel_pose_enc = rel_pose_enc.float()

                predictions["rel_pose_enc"] = self._scale_translation(
                    rel_pose_enc, predicted_scale_factor
                )
                if "pose_enc_list" in rel_result:
                    predictions["rel_pose_enc_list"] = self._scale_translation_list(
                        rel_result["pose_enc_list"], predicted_scale_factor
                    )

                predictions["is_keyframe"] = rel_result.get("is_keyframe")
                predictions["keyframe_indices"] = rel_result.get("keyframe_indices")
                if "global_scale" in rel_result:
                    predictions["global_scale"] = rel_result["global_scale"]
                if "kv_cache_list" in rel_result:
                    predictions["rel_pose_kv_cache_list"] = rel_result[
                        "kv_cache_list"
                    ]

        if self.point_head is not None:
            pts3d, pts3d_conf = self.point_head(
                aggregated_tokens_list,
                images=images,
                patch_start_idx=patch_start_idx,
            )
            predictions["world_points"] = self._scale_dense(
                pts3d, predicted_scale_factor
            )
            predictions["world_points_conf"] = pts3d_conf

        if self.depth_head is not None:
            depth, depth_conf = self.depth_head(
                aggregated_tokens_list,
                images=images,
                patch_start_idx=patch_start_idx,
            )
            predictions["depth"] = self._scale_dense(depth, predicted_scale_factor)
            predictions["depth_conf"] = depth_conf

        if aggregator_kv_cache_list is not None:
            predictions["aggregator_kv_cache_list"] = aggregator_kv_cache_list
        if camera_head_kv_cache_list is not None:
            predictions["camera_head_kv_cache_list"] = camera_head_kv_cache_list

        if not self.training:
            predictions["images"] = images

        return predictions