import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from typing import Optional, Tuple, Union, List, Dict, Any

from src.model.encoder.vggt.layers import PatchEmbed
from src.model.encoder.vggt.layers.block import Block
from src.model.encoder.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from src.model.encoder.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2

logger = logging.getLogger(__name__)

_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]

class Aggregator(nn.Module):
    """
    The Aggregator applies alternating attention over input frames,
    as described in VGGT: Visual Geometry Grounded Transformer.

    Args:
        img_size (int): Image size in pixels.
        patch_size (int): Size of each patch for PatchEmbed.
        embed_dim (int): Dimension of the token embeddings.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        num_register_tokens (int): Number of register tokens.
        block_fn (nn.Module): The block type used for attention (Block by default).
        qkv_bias (bool): Whether to include bias in QKV projections.
        proj_bias (bool): Whether to include bias in the output projection.
        ffn_bias (bool): Whether to include bias in MLP layers.
        patch_embed (str): Type of patch embed, e.g. "conv" or "dinov2_vitl14_reg".
        aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
        aa_block_size (int): Number of consecutive blocks to run per attention type before switching. Set to 1 to alternate every block.
        qk_norm (bool): Whether to apply QK normalization.
        rope_freq (int): Base frequency for rotary embedding. -1 to disable.
        init_values (float): Init scale for layer scale.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        patch_embed="dinov2_vitl14_reg",
        aa_order=["frame", "global"],
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        super().__init__()
        self.use_checkpoint = True
        self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)

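        # 2D rotary position embedding over the patch grid; disabled when rope_freq <= 0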
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

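        # Frame attention blocks: attention within each frame, tokens shaped (B*S, P, C)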
        self.frame_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

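        # Global attention blocks: attention across all frames jointly, tokens shaped (B, S*P, C)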
        self.global_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.depth = depth
        self.aa_order = aa_order
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

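        # Validate that depth is divisible by aa_block_size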
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.aa_block_num = self.depth // self.aa_block_size

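        # Learnable camera and register tokens. Slot 0 is used for the first frame,
        # slot 1 for all remaining frames (see slice_expand_and_flatten).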
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
        self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))

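        # Patch tokens start after the camera token and the register tokens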
        self.patch_start_idx = 1 + num_register_tokens

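        # Initialize the special tokens with a small standard deviation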
        nn.init.normal_(self.camera_token, std=1e-6)
        nn.init.normal_(self.register_token, std=1e-6)

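        # Register the ImageNet normalization constants as non-persistent buffers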
        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.FloatTensor(value).view(1, 1, 3, 1, 1),
                persistent=False,
            )

    def __build_patch_embed__(
        self,
        patch_embed,
        img_size,
        patch_size,
        num_register_tokens,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
        embed_dim=1024,
    ):
        """
        Build the patch embed layer. If 'conv', we use a
        simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
        """
| | if "conv" in patch_embed: |
| | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim) |
| | else: |
| | vit_models = { |
| | "dinov2_vitl14_reg": vit_large, |
| | "dinov2_vitb14_reg": vit_base, |
| | "dinov2_vits14_reg": vit_small, |
| | "dinov2_vitg2_reg": vit_giant2, |
| | } |
| |
|
| | self.patch_embed = vit_models[patch_embed]( |
| | img_size=img_size, |
| | patch_size=patch_size, |
| | num_register_tokens=num_register_tokens, |
| | interpolate_antialias=interpolate_antialias, |
| | interpolate_offset=interpolate_offset, |
| | block_chunks=block_chunks, |
| | init_values=init_values, |
| | ) |
| |
|
| | |
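        # The patch embed's mask token is not used here; freeze it so it gets no gradient updates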
        if hasattr(self.patch_embed, "mask_token"):
            self.patch_embed.mask_token.requires_grad_(False)

    def forward(
        self,
        images: torch.Tensor,
        intermediate_layer_idx: Optional[List[int]] = None,
    ) -> Tuple[List[torch.Tensor], int]:
        """
        Args:
            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
            intermediate_layer_idx (list[int], optional): Indices of the alternating-attention
                layers whose outputs should be kept; the final layer is always included.
                If None, the outputs of all layers are returned.

        Returns:
            (list[torch.Tensor], int):
                The list of outputs from the attention blocks,
                and the patch_start_idx indicating where patch tokens begin.
        """
        B, S, C_in, H, W = images.shape

        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

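        # Normalize images using the ImageNet mean and std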
        images = (images - self._resnet_mean) / self._resnet_std

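        # Flatten batch and sequence dims, then embed each image into patch tokens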
        images = images.view(B * S, C_in, H, W)
        patch_tokens = self.patch_embed(images)

        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        _, P, C = patch_tokens.shape

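        # Expand the camera and register tokens so every frame in every batch gets its own copy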
        camera_token = slice_expand_and_flatten(self.camera_token, B, S)
        register_token = slice_expand_and_flatten(self.register_token, B, S)

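        # Prepend the special tokens to the patch tokens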
        tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)

        pos = None
        if self.rope is not None:
            pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)

        if pos is not None and self.patch_start_idx > 0:
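            # Special tokens (camera and register) get position (0, 0); patch positions
            # are shifted by +1 so they do not collide with the special tokens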
            pos = pos + 1
            pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
            pos = torch.cat([pos_special, pos], dim=1)

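        # Update P to account for the prepended camera and register tokens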
        _, P, C = tokens.shape

        frame_idx = 0
        global_idx = 0
        output_list = []
        layer_idx = 0

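        # When intermediate_layer_idx is given, keep only the requested layers
        # (the final alternating-attention layer is always included)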
        if intermediate_layer_idx is not None:
            required_layers = set(intermediate_layer_idx)
            required_layers.add(self.depth - 1)

        for _ in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S, P, C, frame_idx, pos=pos
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S, P, C, global_idx, pos=pos
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            if intermediate_layer_idx is not None:
                for i in range(len(frame_intermediates)):
                    current_layer = layer_idx + i
                    if current_layer in required_layers:
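                        # Concatenate frame and global intermediates -> [B, S, P, 2C]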
                        concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
                        output_list.append(concat_inter)
                layer_idx += self.aa_block_size
            else:
                for i in range(len(frame_intermediates)):
                    concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
                    output_list.append(concat_inter)

        del concat_inter
        del frame_intermediates
        del global_intermediates
        return output_list, self.patch_start_idx

    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)

        intermediates = []

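        # Run aa_block_size consecutive frame-attention blocks, optionally with
        # activation checkpointing to reduce memory during training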
        for _ in range(self.aa_block_size):
            if self.use_checkpoint:
                tokens = torch.utils.checkpoint.checkpoint(
                    self.frame_blocks[frame_idx],
                    tokens,
                    pos,
                    use_reentrant=False,
                )
            else:
                tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)

        intermediates = []

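        # Run aa_block_size consecutive global-attention blocks, optionally with
        # activation checkpointing to reduce memory during training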
        for _ in range(self.aa_block_size):
            if self.use_checkpoint:
                tokens = torch.utils.checkpoint.checkpoint(
                    self.global_blocks[global_idx],
                    tokens,
                    pos,
                    use_reentrant=False,
                )
            else:
                tokens = self.global_blocks[global_idx](tokens, pos=pos)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates


def slice_expand_and_flatten(token_tensor, B, S):
    """
    Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
    1) Uses the first position (index=0) for the first frame only
    2) Uses the second position (index=1) for all remaining frames (S-1 frames)
    3) Expands both to match batch size B
    4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
       followed by (S-1) second-position tokens
    5) Flattens to (B*S, X, C) for processing

    Returns:
        torch.Tensor: Processed tokens with shape (B*S, X, C)
    """
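    # Slot 0 ("query") is used only for the first frame; expand it over the batch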
    query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
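    # Slot 1 ("others") is repeated for the remaining S - 1 frames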
    others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
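    # Concatenate along the frame dimension -> (B, S, X, C)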
    combined = torch.cat([query, others], dim=1)

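    # Flatten the batch and frame dimensions -> (B*S, X, C)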
    combined = combined.view(B * S, *combined.shape[2:])
    return combined
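

# Worked example for slice_expand_and_flatten (illustrative shapes only): with B=2, S=4
# and the default camera_token of shape (1, 2, 1, C),
#   query    -> (2, 1, 1, C)   slot 0, used only for the first frame
#   others   -> (2, 3, 1, C)   slot 1, repeated for frames 2..4
#   combined -> (2, 4, 1, C), then flattened to (8, 1, C)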
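

# Minimal usage sketch (an illustrative addition, not part of the original module). It assumes
# the DINOv2 ViT-L/14 patch embedding defined in src.model.encoder.vggt.layers can be built
# locally and that enough memory is available for two 518x518 frames.
if __name__ == "__main__":
    aggregator = Aggregator(img_size=518, patch_size=14, embed_dim=1024)
    aggregator.eval()

    images = torch.rand(1, 2, 3, 518, 518)  # [B, S, 3, H, W], values in [0, 1]
    with torch.no_grad():
        output_list, patch_start_idx = aggregator(images)

    # Each entry of output_list is [B, S, P, 2 * embed_dim]: frame and global features concatenated
    last = output_list[-1]
    patch_tokens = last[:, :, patch_start_idx:]  # drop the camera and register tokens
    print(last.shape, patch_tokens.shape)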