Argus / argus /models /aggregator.py
lixi042
Initial commit: Argus metric panoramic 3D reconstruction demo
510e990
Raw
History Blame Contribute Delete
20.7 kB
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from typing import Optional, Tuple, Union, List, Dict, Any
from argus.layers import Mlp
from argus.layers import PatchEmbed
from argus.layers.block import Block
from argus.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from argus.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
from argus.heads.utils import reorder_by_reference
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Per-channel (R, G, B) mean and standard deviation. Aggregator registers them as
# non-persistent buffers and uses them to normalize input images in its forward pass.
# NOTE(review): these values match the commonly used ImageNet statistics — confirm
# they are the ones the patch-embed backbone was trained with.
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
class Aggregator(nn.Module):
    """
    Alternating-attention token aggregator over a sequence of frames.

    Every frame is patch-embedded and prefixed with one camera token and
    ``num_register_tokens`` register tokens. The tokens then run through
    frame-attention blocks (tokens shaped ``[B*S, P, C]``) and global-attention
    blocks (tokens shaped ``[B, S*P, C]``), alternating in the order given by
    ``aa_order``. When ``reorder_by_learning_ref`` is enabled, a small
    alternating-attention stack first scores every frame through a covisibility
    token; the arg-max frame index is handed to ``reorder_by_reference`` before
    the main stack runs.

    Args:
        img_size (int): Image size in pixels.
        patch_size (int): Size of each patch for PatchEmbed.
        embed_dim (int): Dimension of the token embeddings.
        depth (int): Number of blocks (per attention type).
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        num_register_tokens (int): Number of register tokens.
        block_fn (nn.Module): The block type used for attention (Block by default).
        qkv_bias (bool): Whether to include bias in QKV projections.
        proj_bias (bool): Whether to include bias in the output projection.
        ffn_bias (bool): Whether to include bias in MLP layers.
        patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
        aa_order (sequence[str]): The order of alternating attention, e.g. ["frame", "global"].
            The sequence is copied, so the caller's object is never shared or mutated.
        aa_block_size (int): How many blocks to group under each attention type before
            switching. If not necessary, set to 1.
        qk_norm (bool): Whether to apply QK normalization.
        rope_freq (int): Base frequency for rotary embedding. -1 to disable.
        init_values (float): Init scale for layer scale.
        reorder_by_learning_ref (bool): Whether to reorder features by learning reference view index.
        ref_aa_block_num (int): Number of aa blocks for reference view learning. Each aa block
            holds ``aa_block_size`` frame blocks and ``aa_block_size`` global blocks.
        save_inference_memory (bool): In eval mode, replace the intermediates of every
            alternating-attention group whose index is not in ``inference_keep_idx`` by a
            scalar placeholder tensor instead of storing them.
        inference_keep_idx (iterable[int]): Alternating-attention group indices whose
            intermediates are kept when ``save_inference_memory`` is active.

    Raises:
        ValueError: If ``depth`` is not divisible by ``aa_block_size``, or if
            ``reorder_by_learning_ref`` is set while ``ref_aa_block_num`` is below 1.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        patch_embed="dinov2_vitl14_reg",
        aa_order=("frame", "global"),
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
        reorder_by_learning_ref=True,
        ref_aa_block_num=2,
        save_inference_memory=True,
        inference_keep_idx=(4, 11, 17, 23),
    ):
        super().__init__()

        # Fail fast, before any (potentially large) module is allocated.
        if depth % aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
        if reorder_by_learning_ref and ref_aa_block_num < 1:
            raise ValueError(
                f"ref_aa_block_num ({ref_aa_block_num}) must be >= 1 when reorder_by_learning_ref is enabled"
            )

        self.reorder_by_learning_ref = reorder_by_learning_ref
        self.save_inference_memory = save_inference_memory
        self.inference_keep_idx = frozenset(inference_keep_idx)

        # Reads self.reorder_by_learning_ref, so it must run after the assignment above.
        self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)

        # Initialize rotary position embedding if frequency > 0
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        def make_blocks(count):
            # All attention stacks share the same block configuration and rope instance.
            return nn.ModuleList(
                [
                    block_fn(
                        dim=embed_dim,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(count)
                ]
            )

        self.frame_blocks = make_blocks(depth)
        self.global_blocks = make_blocks(depth)

        self.depth = depth
        # Copy so that neither the default nor a caller-owned sequence is shared.
        self.aa_order = list(aa_order)
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size
        self.aa_block_num = self.depth // self.aa_block_size

        # Reference Learning Network
        if self.reorder_by_learning_ref:
            self.ref_aa_block_num = ref_aa_block_num
            # Every attention call consumes aa_block_size blocks, so each stack needs
            # ref_aa_block_num * aa_block_size blocks (identical to the previous
            # layout for the default aa_block_size=1).
            self.ref_frame_blocks = make_blocks(ref_aa_block_num * aa_block_size)
            self.ref_global_blocks = make_blocks(ref_aa_block_num * aa_block_size)

        # Note: We have two camera tokens, one for the first frame and one for the rest
        # The same applies for register tokens
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
        self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
        if self.reorder_by_learning_ref:
            # describe the covisibility of the current frame with other frames
            self.covisibility_token = nn.Parameter(torch.randn(1, 1, 1, embed_dim))

        # The patch tokens start after the camera and register tokens
        self.patch_start_idx = 1 + num_register_tokens

        # Initialize parameters with small values
        nn.init.normal_(self.camera_token, std=1e-6)
        nn.init.normal_(self.register_token, std=1e-6)
        if self.reorder_by_learning_ref:
            nn.init.normal_(self.covisibility_token, std=1e-6)

        # Register normalization constants as buffers, broadcastable against [B, S, 3, H, W]
        for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
            self.register_buffer(
                name,
                torch.tensor(value, dtype=torch.float32).view(1, 1, 3, 1, 1),
                persistent=False,
            )

        self.use_reentrant = False  # hardcoded to False

    def __build_patch_embed__(
        self,
        patch_embed,
        img_size,
        patch_size,
        num_register_tokens,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
        embed_dim=1024,
    ):
        """
        Build the patch embed layer. If 'conv', we use a
        simple PatchEmbed conv layer. Otherwise, we use a vision transformer.

        When ``self.reorder_by_learning_ref`` is set, this also creates the
        covisibility head (LayerNorm + Mlp over ``2 * embed_dim`` features).

        Raises:
            KeyError: If ``patch_embed`` does not contain "conv" and is not one of
                the known vision transformer names.
        """
        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }
            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )
            # Remove the backbone's mask token parameter entirely (rather than
            # merely freezing it), so it is not part of this module's parameters.
            if hasattr(self.patch_embed, "mask_token"):
                del self.patch_embed.mask_token

        # covisibility head
        if self.reorder_by_learning_ref:
            self.token_norm = nn.LayerNorm(embed_dim * 2)
            self.covisibility_head = Mlp(
                in_features=embed_dim * 2,
                hidden_features=embed_dim,  # half of the input width
                out_features=1,
                drop=0,
            )

    def forward(
        self, images: torch.Tensor
    ) -> Tuple[List[torch.Tensor], int, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Args:
            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width

        Returns:
            (list[torch.Tensor], int, torch.Tensor | None, torch.Tensor | None):
                - The list of outputs from the attention blocks. A stored entry is the
                  frame and global outputs concatenated on the last axis, [B, S, P, 2C].
                  In eval mode with ``save_inference_memory`` enabled, entries of groups
                  not listed in ``inference_keep_idx`` are scalar placeholders
                  (``torch.tensor(0)``) so list indices stay aligned.
                - The patch_start_idx indicating where patch tokens begin.
                - The covisibility scores [B, S], or None when
                  ``reorder_by_learning_ref`` is disabled.
                - The arg-max reference index [B], or None when
                  ``reorder_by_learning_ref`` is disabled.

        Raises:
            ValueError: If the channel dimension is not 3, or ``aa_order`` contains
                an unknown attention type.
        """
        B, S, C_in, H, W = images.shape
        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

        # Normalize images and reshape for patch embed
        images = (images - self._resnet_mean) / self._resnet_std
        # Flatten to [B*S, C, H, W]; reshape (not view) also accepts non-contiguous input.
        images = images.reshape(B * S, C_in, H, W)

        patch_tokens = self.patch_embed(images)
        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]
        _, P, C = patch_tokens.shape

        # The patch position grid is shared by the reference pass and the main pass.
        grid_pos = None
        if self.rope is not None:
            grid_pos = self.position_getter(
                B * S, H // self.patch_size, W // self.patch_size, device=images.device
            )

        ################# ref learning
        covisibility_scores = None
        ref_idx = None
        if self.reorder_by_learning_ref:
            covisibility_scores, ref_idx = self._score_covisibility(patch_tokens, grid_pos, B, S, images.device)
            # NOTE(review): reorder_by_reference is defined in argus.heads.utils; it is
            # expected to reorder the S axis according to ref_idx — confirm there.
            patch_tokens = reorder_by_reference(patch_tokens.reshape(B, S, P, C), ref_idx)
            patch_tokens = patch_tokens.reshape(B * S, P, C).contiguous()
        ####################

        # Expand camera and register tokens to match batch size and sequence length
        camera_token = slice_expand_and_flatten(self.camera_token, B, S)
        register_token = slice_expand_and_flatten(self.register_token, B, S)
        # Concatenate special tokens with patch tokens
        tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)  # [BS, 1+R+HW, C]

        # do not use position embedding for special tokens (camera and register tokens)
        pos = self._prepend_special_positions(grid_pos, self.patch_start_idx, images.device)  # [BS, 1+R+HW, 2]

        # update P because we added special tokens
        _, P, C = tokens.shape

        frame_idx = 0
        global_idx = 0
        output_list = []
        keep_everything = self.training or not self.save_inference_memory

        for block_i in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S, P, C, frame_idx, pos=pos
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S, P, C, global_idx, pos=pos
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            if keep_everything or block_i in self.inference_keep_idx:
                # concat frame and global intermediates, [B x S x P x 2C]
                output_list.extend(
                    torch.cat([frame_inter, global_inter], dim=-1)
                    for frame_inter, global_inter in zip(frame_intermediates, global_intermediates)
                )
            else:
                # Only the useful indices are stored; the large concat is never
                # materialized for the others, a placeholder keeps the indices aligned.
                output_list.extend(torch.tensor(0) for _ in frame_intermediates)

            # Drop the references now so the activations can be freed early.
            del frame_intermediates
            del global_intermediates

        return output_list, self.patch_start_idx, covisibility_scores, ref_idx

    def _prepend_special_positions(self, grid_pos, num_special, device):
        """
        Build positions for ``num_special`` special tokens followed by the patch grid.

        Special tokens get position 0 and the patch positions are shifted by +1 so
        that they never collide with the special tokens. Returns None when rotary
        embedding is disabled (``grid_pos`` is None) and the unshifted grid when
        there are no special tokens.
        """
        if grid_pos is None:
            return None
        if num_special <= 0:
            return grid_pos
        special = torch.zeros(grid_pos.shape[0], num_special, 2, device=device, dtype=grid_pos.dtype)
        return torch.cat([special, grid_pos + 1], dim=1)

    def _score_covisibility(self, patch_tokens, grid_pos, B, S, device):
        """
        Run the reference-learning stack and score every frame.

        A covisibility token is prepended to the patch tokens of each frame and the
        sequence is passed through the reference frame/global blocks. The frame and
        global outputs at the covisibility slot of the last blocks are concatenated,
        normalized and mapped to one score per frame.

        Returns:
            (torch.Tensor, torch.Tensor): scores [B, S] and their arg-max over S, [B].
        """
        _, _, C = patch_tokens.shape
        # expand covisibility token to match batch size and sequence length
        covisibility_token = self.covisibility_token.reshape(1, 1, -1).expand(B * S, 1, C)
        # Concatenate covisibility token with patch tokens
        tokens = torch.cat([covisibility_token, patch_tokens], dim=1)  # [BS, 1+HW, C]
        # do not use position embedding for the special covisibility token
        pos = self._prepend_special_positions(grid_pos, 1, device)  # [BS, 1+HW, 2]
        _, P_covis, C_covis = tokens.shape

        frame_idx = 0
        global_idx = 0
        for _ in range(self.ref_aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._ref_process_frame_attention(
                        tokens, B, S, P_covis, C_covis, frame_idx, pos=pos
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._ref_process_global_attention(
                        tokens, B, S, P_covis, C_covis, global_idx, pos=pos
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

        # Only the covisibility slot (token 0) of the last frame/global outputs is
        # needed, so slice before concatenating instead of building [B, S, P, 2C].
        covisibility_features = torch.cat(
            [frame_intermediates[-1][:, :, 0], global_intermediates[-1][:, :, 0]], dim=-1
        )  # [B, S, 2C]
        covisibility_features = self.token_norm(covisibility_features)
        covisibility_scores = self.covisibility_head(covisibility_features).squeeze(-1)  # [B, S]
        ref_idx = covisibility_scores.argmax(-1)  # [B, S] -> [B]
        return covisibility_scores, ref_idx

    def _run_blocks(self, blocks, tokens, flat_shape, full_shape, block_idx, pos):
        """
        Apply ``self.aa_block_size`` consecutive blocks from ``blocks``.

        Args:
            blocks: Module list to take the blocks from, starting at ``block_idx``.
            tokens: Token tensor; reshaped to ``flat_shape`` when it has another shape.
            flat_shape: Token layout expected by the blocks, ending in the channel dim.
            full_shape: ``(B, S, P, C)`` layout used for the stored intermediates.
            block_idx: Index of the first block to run.
            pos: Optional positions; reshaped to ``flat_shape[:-1] + (2,)`` if needed.

        Returns:
            (tokens, next_block_idx, intermediates) where ``intermediates`` holds the
            output of every block reshaped to ``full_shape``.
        """
        if tokens.shape != flat_shape:
            tokens = tokens.reshape(flat_shape)
        pos_shape = flat_shape[:-1] + (2,)
        if pos is not None and pos.shape != pos_shape:
            pos = pos.reshape(pos_shape)

        intermediates = []
        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            block = blocks[block_idx]
            if self.training:
                # Activation checkpointing trades recomputation for memory in training.
                tokens = checkpoint(block, tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = block(tokens, pos=pos)
            block_idx += 1
            intermediates.append(tokens.reshape(full_shape))
        return tokens, block_idx, intermediates

    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        return self._run_blocks(self.frame_blocks, tokens, (B * S, P, C), (B, S, P, C), frame_idx, pos)

    def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        return self._run_blocks(self.global_blocks, tokens, (B, S * P, C), (B, S, P, C), global_idx, pos)

    def _ref_process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        """
        Process reference-learning frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        return self._run_blocks(self.ref_frame_blocks, tokens, (B * S, P, C), (B, S, P, C), frame_idx, pos)

    def _ref_process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
        """
        Process reference-learning global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        return self._run_blocks(self.ref_global_blocks, tokens, (B, S * P, C), (B, S, P, C), global_idx, pos)
def slice_expand_and_flatten(token_tensor, B, S):
    """
    Lay out a pair of learned tokens over a batch of frame sequences.

    ``token_tensor`` has shape (1, 2, X, C): entry 0 along the second axis is the
    token set for the first frame of every sequence, entry 1 is the token set
    shared by the remaining S-1 frames. Both are broadcast over the batch,
    joined along the frame axis into (B, S, X, C) and flattened.

    Returns:
        torch.Tensor: Tokens with shape (B*S, X, C); within each group of S rows
        the first row carries entry 0 and the others carry entry 1.
    """
    trailing = tuple(token_tensor.shape[2:])

    # First-frame tokens, broadcast over the batch => (B, 1, X, C)
    first_frame = token_tensor[:, :1].expand(B, 1, *trailing)
    # Tokens for every other frame, broadcast over batch and frames => (B, S-1, X, C)
    other_frames = token_tensor[:, 1:].expand(B, S - 1, *trailing)

    # Join along the frame axis => (B, S, X, C), then merge batch and frame axes
    per_frame = torch.cat((first_frame, other_frames), dim=1)
    return per_frame.view(B * S, *trailing)