# NOTE(review): stray revision residue ("Cc" / "init" / e340a84) commented out —
# these bare identifiers would raise NameError at import time.
from typing import List, Tuple, Optional, Union
import torch
import torch.nn as nn
from ..layers import Mlp
from ..layers.block import Block
from .head_act import activate_pose
class CameraHead(nn.Module):
    """
    CameraHead predicts camera parameters from token representations using
    iterative refinement.

    Each refinement iteration embeds the current pose estimate, modulates the
    dedicated camera tokens via adaptive LayerNorm (adaLN), runs them through
    a small transformer trunk, and predicts a residual pose update.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",
        window_size: int = 5,
    ):
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            # 3 translation + 4 quaternion + 2 field-of-view components.
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth
        self.window_size = window_size

        # Transformer trunk applied to the camera tokens at every iteration.
        self.trunk = nn.Sequential(
            *[
                Block(
                    dim=dim_in,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    init_values=init_values,
                )
                for _ in range(trunk_depth)
            ]
        )

        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable "empty" pose fed in on the first refinement iteration.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # adaLN modulation: produces shift / scale / gate from the pose embedding.
        self.poseLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)
        )
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)

        # Maps trunk output back to the pose-encoding dimension.
        self.pose_branch = Mlp(
            in_features=dim_in,
            hidden_features=dim_in // 2,
            out_features=self.target_dim,
            drop=0,
        )

    def _create_attn_mask(
        self, S: int, mode: str, dtype: torch.dtype, device: torch.device
    ) -> Optional[torch.Tensor]:
        """
        Build an additive [S, S] attention mask for the requested mode.

        Args:
            S: Number of views (one camera token per view).
            mode: "causal", "window", or "full".
            dtype: dtype of the returned mask.
            device: device of the returned mask.

        Returns:
            A float mask with 0 for allowed and -inf for blocked positions,
            or None for "full" (unrestricted) attention.
        """
        mask = torch.zeros((S, S), dtype=dtype, device=device)
        if mode == "causal":
            # Each view attends only to itself and earlier views.
            for i in range(S):
                mask[i, i + 1 :] = float("-inf")
        elif mode == "window":
            # Each view attends to the first view (column 0 stays unmasked as
            # an anchor) plus a trailing window of `window_size` views.
            for i in range(S):
                mask[i, 1:] = float("-inf")
                start_view = max(1, i - self.window_size + 1)
                mask[i, start_view : (i + 1)] = 0
        elif mode == "full":
            mask = None
        else:
            raise NotImplementedError(f"Unknown attention mode: {mode}")
        return mask

    def forward(
        self,
        aggregated_tokens_list: list,
        num_iterations: int = 4,
        mode: str = "causal",
        kv_cache_list: Optional[List[List[List[torch.Tensor]]]] = None,
    ) -> Union[list, Tuple[list, List[List[List[torch.Tensor]]]]]:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps.
                Defaults to 4.
            mode (str): Global attention mode — "causal", "window" or "full".
            kv_cache_list: Optional cached key/value pairs, indexed by refinement
                iteration and trunk layer. When provided, no attention mask is
                built (streaming / incremental decoding).

        Returns:
            A list of predicted camera encodings (post-activation) from each
            iteration; when kv_cache_list is given, a
            (list, kv_cache_list) tuple instead.
        """
        tokens = aggregated_tokens_list[-1]

        # The camera token sits at patch index 0 of every view.
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)
        B, S, C = pose_tokens.shape

        attn_mask = None
        if kv_cache_list is None:
            attn_mask = self._create_attn_mask(
                S, mode, pose_tokens.dtype, pose_tokens.device
            )
        return self.trunk_fn(pose_tokens, num_iterations, attn_mask, kv_cache_list)

    def trunk_fn(
        self,
        pose_tokens: torch.Tensor,
        num_iterations: int,
        attn_mask: Optional[torch.Tensor],
        kv_cache_list: Optional[List[List[List[torch.Tensor]]]] = None,
    ) -> Union[list, Tuple[list, List[List[List[torch.Tensor]]]]]:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
            num_iterations (int): Number of refinement iterations.
            attn_mask: Optional additive attention mask for the trunk blocks.
            kv_cache_list: Optional KV caches indexed by [iteration][layer];
                updated in place while the trunk runs.

        Returns:
            A list of activated camera encodings from each iteration; when
            kv_cache_list is given, a (list, kv_cache_list) tuple instead.
        """
        B, S, C = pose_tokens.shape
        pred_pose_enc = None
        pred_pose_enc_list = []

        # Renamed from `iter` to avoid shadowing the builtin.
        for iteration in range(num_iterations):
            # Embed the current estimate (learned "empty" pose on the first pass).
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                module_input = self.embed_pose(pred_pose_enc)

            # adaLN modulation conditioned on the current pose estimate.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(
                3, dim=-1
            )
            adaln_output = self.adaln_norm(pose_tokens)
            modulated_output = modulate(adaln_output, shift_msa, scale_msa)
            gated_output = gate_msa * modulated_output
            pose_tokens_modulated = gated_output + pose_tokens

            for i in range(self.trunk_depth):
                if kv_cache_list is not None:
                    pose_tokens_modulated, kv_cache_list[iteration][i] = self.trunk[i](
                        pose_tokens_modulated,
                        attn_mask=attn_mask,
                        kv_cache=kv_cache_list[iteration][i],
                    )
                else:
                    pose_tokens_modulated = self.trunk[i](
                        pose_tokens_modulated, attn_mask=attn_mask
                    )

            trunk_norm_output = self.trunk_norm(pose_tokens_modulated)

            # Predict a residual update and accumulate it onto the estimate.
            pred_pose_enc_delta = self.pose_branch(trunk_norm_output)
            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            activated_pose = activate_pose(
                pred_pose_enc,
                trans_act=self.trans_act,
                quat_act=self.quat_act,
                fl_act=self.fl_act,
            )
            pred_pose_enc_list.append(activated_pose)

        if kv_cache_list is not None:
            return pred_pose_enc_list, kv_cache_list
        return pred_pose_enc_list
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Apply an adaLN-style affine transformation to *x*.

    The input is scaled by ``1 + scale`` and then offset by ``shift``; the
    three tensors must be broadcast-compatible.
    """
    scaled = x * (1 + scale)
    return scaled + shift
class RelPoseHead(nn.Module):
    """
    Enhanced Relative Pose Head for dynamic keyframe-based pose prediction.

    Key features:
        1. True relative pose prediction (not incremental from fixed anchor)
        2. Dynamic keyframe switching support
        3. SE(3) and Sim(3) pose modes
        4. Role-aware processing for keyframes vs non-keyframes
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_mode: str = "SE3",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",
        use_global_scale: bool = False,
        use_pair_cross_attn: bool = False,
        detach_reference: bool = False,
        xattn_temperature: float = 1.0,
        use_precat: bool = False,
        use_kf_role_embed: bool = True,
        kf_role_embed_init_std: float = 0.02,
        window_size: int = 50000,
    ):
        super().__init__()
        self.pose_mode = pose_mode
        # Global scale only applies in Sim(3) mode.
        self.use_global_scale = use_global_scale and (pose_mode == "Sim3")
        self.use_pair_cross_attn = use_pair_cross_attn
        self.detach_reference = detach_reference
        self.xattn_temperature = xattn_temperature
        self.use_precat = use_precat
        self.use_kf_role_embed = use_kf_role_embed
        self.kf_role_embed_init_std = kf_role_embed_init_std

        # 3 translation + 4 quaternion + 2 field-of-view components.
        self.target_dim = 9
        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth
        # Bug fix: honor the constructor argument. Previously this was
        # hardcoded to 50000, silently ignoring the caller's `window_size`
        # (the default is unchanged, so default construction behaves the same).
        self.window_size = window_size

        # Transformer trunk applied to the camera tokens at every iteration.
        self.trunk = nn.Sequential(
            *[
                Block(
                    dim=dim_in,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    init_values=init_values,
                )
                for _ in range(trunk_depth)
            ]
        )

        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable "empty" pose fed in on the first refinement iteration.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # adaLN modulation: produces shift / scale / gate from the pose embedding.
        self.poseLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)
        )
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)

        self.pose_branch = Mlp(
            in_features=dim_in,
            hidden_features=dim_in // 2,
            out_features=self.target_dim,
            drop=0,
        )

        if self.use_global_scale:
            self.global_scale = nn.Parameter(torch.ones(1))

        # Optional gated cross-attention between each frame and its reference.
        if self.use_pair_cross_attn:
            self.xattn_q = nn.Linear(dim_in, dim_in, bias=False)
            self.xattn_k = nn.Linear(dim_in, dim_in, bias=False)
            self.xattn_v = nn.Linear(dim_in, dim_in, bias=False)
            self.xattn_out = nn.Linear(dim_in, dim_in, bias=False)

        # Optional concatenation of frame + reference tokens before the trunk.
        if self.use_precat:
            self.precat_proj = nn.Linear(dim_in * 2, dim_in, bias=True)

        # Role embedding added to reference tokens of non-self-referencing frames.
        if self.use_kf_role_embed:
            self.kf_role_embed = nn.Parameter(torch.randn(1, 1, dim_in))
            nn.init.normal_(self.kf_role_embed, std=self.kf_role_embed_init_std)
        else:
            self.kf_role_embed = None

    def _create_attn_mask(
        self, S: int, mode: str, dtype: torch.dtype, device: torch.device
    ) -> Optional[torch.Tensor]:
        """
        Create an additive [S, S] attention mask for the given mode.

        Returns 0 for allowed and -inf for blocked positions, or None for
        "full" (unrestricted) attention.
        """
        N = S
        if mode == "causal":
            mask = torch.zeros((N, N), dtype=dtype, device=device)
            for i in range(S):
                mask[i, i + 1 :] = float("-inf")
            return mask
        elif mode == "window":
            # Sliding causal window of the most recent `window_size` frames.
            mask = torch.zeros((N, N), dtype=dtype, device=device)
            for i in range(S):
                mask[i, :] = float("-inf")
                start = max(0, i - self.window_size + 1)
                mask[i, start : i + 1] = 0
            return mask
        elif mode == "full":
            return None
        else:
            raise NotImplementedError(f"Unknown attention mode: {mode}")

    def forward(
        self,
        aggregated_tokens_list: list,
        keyframe_indices: torch.Tensor,
        is_keyframe: torch.Tensor,
        num_iterations: int = 4,
        mode: str = "causal",
        kv_cache_list: Optional[List[List[List[torch.Tensor]]]] = None,
        compute_switch_poses: bool = False,
    ):
        """
        Forward pass for relative pose prediction.

        Args:
            aggregated_tokens_list: List of aggregated tokens from the network;
                only the last tensor is used.
            keyframe_indices: Indices of reference keyframes for each frame [B, S].
            is_keyframe: Boolean mask indicating keyframes [B, S].
            num_iterations: Number of iterative refinement steps.
            mode: Attention mode ("causal", "window", or "full").
            kv_cache_list: Optional KV cache for streaming (S == 1 per call).
            compute_switch_poses: If True, also compute keyframe-switch poses.

        Returns:
            dict containing:
                - pose_enc: Predicted relative poses [B, S, 9] (float32)
                - is_keyframe: Keyframe mask [B, S]
                - keyframe_indices: Reference keyframe indices [B, S]
                - pose_enc_list: Per-iteration encodings (training only)
                - switch_poses: Keyframe-switch poses (if requested)
                - global_scale: Global scale for Sim(3) mode (if applicable)
                - kv_cache_list: Updated caches (if provided)
        """
        # NOTE(review): the caller's `mode` is force-overridden here even
        # though the masking code below supports "window"/"full" — confirm
        # this is intentional.
        mode = "causal"

        tokens = aggregated_tokens_list[-1]
        # The camera token sits at patch index 0 of every view.
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)
        B, S, C = pose_tokens.shape

        # Streaming mode (one frame at a time with a KV cache): record
        # per-frame keyframe metadata and cache keyframe tokens for later
        # lookups as references.
        if kv_cache_list is not None and S == 1:
            if not hasattr(self, "_keyframe_tokens_cache"):
                self._keyframe_tokens_cache = {}
                self._current_frame_id = 0
                self._frame_info = []
            curr_is_kf = is_keyframe[0, 0].item() if is_keyframe is not None else True
            curr_ref_idx = (
                keyframe_indices[0, 0].item()
                if keyframe_indices is not None
                else self._current_frame_id
            )
            self._frame_info.append((curr_is_kf, curr_ref_idx))
            if curr_is_kf:
                self._keyframe_tokens_cache[
                    self._current_frame_id
                ] = pose_tokens.squeeze(1)
            self._current_frame_id += 1

        # Gather the reference (keyframe) token for each frame.
        ref_tokens = None
        if keyframe_indices is not None:
            if kv_cache_list is not None and S == 1:
                ref_frame_id = keyframe_indices[0, 0].item()
                if ref_frame_id in self._keyframe_tokens_cache:
                    ref_tokens = self._keyframe_tokens_cache[ref_frame_id].unsqueeze(1)
                else:
                    # Reference not cached yet: fall back to the current token.
                    ref_tokens = pose_tokens
                if self.detach_reference:
                    ref_tokens = ref_tokens.detach()
            else:
                total_frames = pose_tokens.shape[1]
                ref_idx = (
                    keyframe_indices.clamp(0, total_frames - 1)
                    .unsqueeze(-1)
                    .expand(-1, -1, C)
                )
                ref_tokens = torch.gather(pose_tokens, dim=1, index=ref_idx)
                if self.detach_reference:
                    ref_tokens = ref_tokens.detach()

        # Add the keyframe-role embedding to references that are not the
        # frame itself (i.e. the frame points at a different keyframe).
        if (
            self.use_kf_role_embed
            and ref_tokens is not None
            and self.kf_role_embed is not None
        ):
            current_indices = (
                torch.arange(S, device=keyframe_indices.device)
                .unsqueeze(0)
                .expand(B, -1)
            )
            is_self_ref = current_indices == keyframe_indices
            add_kf_embed_mask = (~is_self_ref).unsqueeze(-1).float()
            ref_tokens = ref_tokens + add_kf_embed_mask * self.kf_role_embed.expand(
                B, S, -1
            )

        # Gated pairwise cross-attention: scalar similarity gate applied to
        # the reference's value projection, then added back to the tokens.
        if self.use_pair_cross_attn and (ref_tokens is not None):
            q = self.xattn_q(pose_tokens)
            k = self.xattn_k(ref_tokens)
            v = self.xattn_v(ref_tokens)
            scale = (q * k).sum(dim=-1, keepdim=True) / (C ** 0.5)
            gate = torch.sigmoid(scale / self.xattn_temperature)
            pair_info = self.xattn_out(gate * v)
            pose_tokens = pose_tokens + pair_info

        # Optionally fuse frame + reference tokens via concatenation.
        if self.use_precat and (ref_tokens is not None):
            pose_tokens = self.precat_proj(torch.cat([pose_tokens, ref_tokens], dim=-1))

        # --- Attention-mask construction -------------------------------------
        attn_mask = None
        if keyframe_indices is not None:
            ref = keyframe_indices
            B = ref.shape[0]
            j_indices = (
                torch.arange(S, device=pose_tokens.device)
                .view(1, 1, S)
                .expand(B, S, -1)
            )
            # Non-keyframes may attend to their reference keyframe and to
            # frames that share the same reference.
            is_ref_frame = j_indices == ref[:, :, None]
            same_ref = ref[:, :, None] == ref[:, None, :]
            can_attend_nonkf = is_ref_frame | same_ref

            idx = torch.arange(S, device=pose_tokens.device)
            if mode == "causal":
                causal_mask = idx[None, :, None] >= idx[None, None, :]
                # Keyframes attend from the previous keyframe up to themselves.
                prev_kf_mask = torch.zeros(
                    B, S, S, dtype=torch.bool, device=pose_tokens.device
                )
                for b in range(B):
                    kf_positions = [i for i in range(S) if is_keyframe[b, i]]
                    for kf_idx, kf_pos in enumerate(kf_positions):
                        if kf_idx == 0:
                            prev_kf_mask[b, kf_pos, : kf_pos + 1] = True
                        else:
                            prev_kf_pos = kf_positions[kf_idx - 1]
                            prev_kf_mask[b, kf_pos, prev_kf_pos : kf_pos + 1] = True
                can_attend_kf = causal_mask.expand(B, -1, -1) & prev_kf_mask
                mode_constraint = causal_mask
            elif mode == "window":
                causal_mask = idx[None, :, None] >= idx[None, None, :]
                window_mask = (
                    idx[None, :, None] - idx[None, None, :]
                ) < self.window_size
                window_causal = causal_mask & window_mask
                can_attend_kf = window_causal.expand(B, -1, -1)
                mode_constraint = window_causal
            elif mode == "full":
                can_attend_kf = torch.ones(
                    B, S, S, dtype=torch.bool, device=pose_tokens.device
                )
                mode_constraint = torch.ones(
                    1, S, S, dtype=torch.bool, device=pose_tokens.device
                )
            else:
                raise NotImplementedError(f"Unknown mode: {mode}")

            # Keyframes use the keyframe rule, non-keyframes the reference rule.
            is_kf_expanded = is_keyframe[:, :, None].expand(-1, -1, S)
            can_attend = torch.where(is_kf_expanded, can_attend_kf, can_attend_nonkf)
            mask_bool = can_attend & mode_constraint

            zero = torch.zeros(1, dtype=pose_tokens.dtype, device=pose_tokens.device)
            neg_inf = torch.full(
                (1,), float("-inf"), dtype=pose_tokens.dtype, device=pose_tokens.device
            )
            attn_mask = torch.where(mask_bool, zero, neg_inf)[:, None, :, :]

            # Streaming: prepend mask columns that cover the cached past frames.
            if kv_cache_list is not None and len(kv_cache_list) > 0 and S == 1:
                k_cache = kv_cache_list[0][0][0]
                if k_cache is not None:
                    cache_len = k_cache.shape[2]
                    curr_idx = len(self._frame_info) - 1
                    curr_is_kf, curr_ref_idx = self._frame_info[curr_idx]
                    cache_mask_vals = []
                    visible_frames = []
                    for i in range(max(0, curr_idx - cache_len), curr_idx):
                        cache_is_kf, cache_ref_idx = self._frame_info[i]
                        if curr_is_kf:
                            # A keyframe can see back to the previous keyframe.
                            prev_kf_idx = None
                            for j in range(curr_idx - 1, -1, -1):
                                if self._frame_info[j][0]:
                                    prev_kf_idx = j
                                    break
                            if prev_kf_idx is not None:
                                can_see = i >= prev_kf_idx
                            else:
                                can_see = True
                        else:
                            # A non-keyframe sees its reference and frames
                            # sharing the same reference.
                            is_ref = i == curr_ref_idx
                            same_ref = cache_ref_idx == curr_ref_idx
                            can_see = is_ref or same_ref
                        mask_val = zero if can_see else neg_inf
                        cache_mask_vals.append(mask_val)
                        if can_see:
                            visible_frames.append(i)
                    if len(cache_mask_vals) > 0:
                        cache_mask = torch.stack(cache_mask_vals, dim=0).view(
                            1, 1, 1, len(cache_mask_vals)
                        )
                        cache_mask = cache_mask.expand(B, 1, S, len(cache_mask_vals))
                        attn_mask = torch.cat([cache_mask, attn_mask], dim=-1)
        else:
            attn_mask = self._create_attn_mask(
                S, mode, pose_tokens.dtype, pose_tokens.device
            )

        # --- Iterative refinement (same adaLN + trunk scheme as CameraHead) --
        pred_pose_enc_list = []
        pred_pose_enc = None
        for iter_idx in range(num_iterations):
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                module_input = self.embed_pose(pred_pose_enc)
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(
                3, dim=-1
            )
            pose_tokens_modulated = gate_msa * modulate(
                self.adaln_norm(pose_tokens), shift_msa, scale_msa
            )
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens
            for i in range(self.trunk_depth):
                if (
                    kv_cache_list is not None
                    and iter_idx < len(kv_cache_list)
                    and i < len(kv_cache_list[iter_idx])
                ):
                    pose_tokens_modulated, kv_cache_list[iter_idx][i] = self.trunk[i](
                        pose_tokens_modulated,
                        attn_mask=attn_mask,
                        kv_cache=kv_cache_list[iter_idx][i],
                    )
                else:
                    pose_tokens_modulated = self.trunk[i](
                        pose_tokens_modulated, attn_mask=attn_mask
                    )
            trunk_norm_output = self.trunk_norm(pose_tokens_modulated)
            # Residual pose update, accumulated across iterations.
            pred_pose_enc_delta = self.pose_branch(trunk_norm_output)
            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
            activated_pose = activate_pose(
                pred_pose_enc,
                trans_act=self.trans_act,
                quat_act=self.quat_act,
                fl_act=self.fl_act,
            )
            pred_pose_enc_list.append(activated_pose)

        final_pose_enc = pred_pose_enc_list[-1]
        if final_pose_enc.dtype != torch.float32:
            final_pose_enc = final_pose_enc.float()

        result = {
            "pose_enc": final_pose_enc,
            "is_keyframe": is_keyframe,
            "keyframe_indices": keyframe_indices,
        }
        if self.training and len(pred_pose_enc_list) > 0:
            result["pose_enc_list"] = pred_pose_enc_list
        if compute_switch_poses:
            switch_poses = self._compute_switch_poses(
                pred_pose_enc_list[-1], keyframe_indices, is_keyframe
            )
            result["switch_poses"] = switch_poses
        if self.use_global_scale:
            result["global_scale"] = self.global_scale.expand(B, 1)
        if kv_cache_list is not None:
            result["kv_cache_list"] = kv_cache_list
        return result

    def _compute_switch_poses(self, poses, keyframe_indices, is_keyframe):
        """
        Compute T_{k'<-k} for keyframe switches.

        Returns a dictionary mapping (batch, prev_keyframe, keyframe) triples
        to the (cloned) pose encoding at the new keyframe.
        """
        B, S, _ = poses.shape
        switch_poses = {}
        for b in range(B):
            prev_kf_idx = None
            for s in range(S):
                if is_keyframe[b, s]:
                    if prev_kf_idx is not None:
                        key = (b, prev_kf_idx, s)
                        switch_poses[key] = poses[b, s].clone()
                    prev_kf_idx = s
        return switch_poses