Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE-VGGT file in the root directory of this source tree. | |
| import logging | |
| import torch | |
| from torch import nn, Tensor | |
| from torch.utils.checkpoint import checkpoint | |
| from typing import List, Callable | |
| from dataclasses import dataclass | |
| from einops import repeat | |
| from vggt.layers.block import drop_add_residual_stochastic_depth | |
| from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter | |
| from vggt.layers.attention import Attention | |
| from vggt.layers.drop_path import DropPath | |
| from vggt.layers.layer_scale import LayerScale | |
| from vggt.layers.mlp import Mlp | |
| logger = logging.getLogger(__name__) | |
@dataclass
class ModulationOut:
    """Container for one set of modulation tensors.

    Declared as a dataclass so it can be built positionally, which is how
    ``Modulation.forward`` constructs it (``ModulationOut(*chunks)``); without
    the generated ``__init__`` that call raises ``TypeError``.
    """

    # Additive term applied after the normalised activations are scaled.
    shift: Tensor
    # Multiplicative term; applied by the consumer as ``(1 + scale)``.
    scale: Tensor
    # Multiplier applied to the branch output before the residual add.
    gate: Tensor
class Modulation(nn.Module):
    """Project a conditioning vector into shift/scale/gate modulation tensors.

    A single linear layer maps ``dim`` features to ``3 * dim`` (single) or
    ``6 * dim`` (double) features, which are then cut into equally sized
    chunks along the last dimension.
    """

    def __init__(self, dim: int, double: bool):
        """
        Args:
            dim: feature size of the conditioning vector and of every chunk.
            double: when True, produce two modulation triplets instead of one.
        """
        super().__init__()
        self.is_double = double
        if double:
            self.multiplier = 6
        else:
            self.multiplier = 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        """Return ``(first, second)``; ``second`` is None unless ``double`` was set.

        ``vec`` is indexed as a 2-D tensor; a singleton axis is inserted in the
        middle so each chunk broadcasts over a token dimension.
        """
        activated = nn.functional.silu(vec)
        projected = self.lin(activated)
        pieces = projected[:, None, :].chunk(self.multiplier, dim=-1)
        primary = ModulationOut(*pieces[:3])
        if not self.is_double:
            return primary, None
        return primary, ModulationOut(*pieces[3:])
class ConditionalBlock(nn.Module):
    """Transformer block whose attention branch is modulated by a condition.

    The attention branch normalises the input without learnable affine
    parameters, then applies a shift/scale produced by ``Modulation`` from the
    conditioning tokens, and finally gates the attention output. The MLP
    branch is a regular pre-norm feed-forward with optional ``LayerScale``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        """
        Args:
            dim: token feature size.
            num_heads, qkv_bias, proj_bias, attn_drop, qk_norm, fused_attn, rope:
                forwarded unchanged to ``attn_class``.
            mlp_ratio: hidden size of the MLP is ``int(dim * mlp_ratio)``.
            ffn_bias, act_layer: forwarded to ``ffn_layer``.
            drop: dropout used for the attention projection and the MLP.
            init_values: if truthy, the MLP branch is wrapped in ``LayerScale``.
            drop_path: stochastic-depth rate; also stored as
                ``sample_drop_ratio``.
            norm_layer: factory for both normalisation layers.
        """
        super().__init__()
        # No learnable affine here: scale and shift come from the modulation.
        self.norm1 = norm_layer(dim, elementwise_affine=False)
        self.modulation = Modulation(dim, double=False)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, cond=None, is_global=False) -> Tensor:
        """Apply the conditioned attention branch followed by the MLP branch.

        Args:
            x: tokens. With ``is_global`` false they are modulated as-is, so
                the leading dimension must broadcast against ``B * S``; with
                ``is_global`` true they are reshaped from ``(B, S * P, C)``.
            pos: optional positions forwarded to the attention module.
            cond: conditioning tokens; the first two dimensions are read as
                ``(B, S)`` and the tensor is flattened to ``(B * S, C)``.
                Required despite the ``None`` default.
            is_global: whether ``x`` holds all ``S`` frames in one sequence.

        Raises:
            ValueError: if ``cond`` is not provided.
        """
        if cond is None:
            raise ValueError("ConditionalBlock.forward requires `cond` (conditioning tokens)")

        B, S = cond.shape[:2]
        C = x.shape[-1]
        # Tokens per frame; only meaningful (and only used) in global mode.
        P = x.shape[1] // S if is_global else None

        cond = cond.reshape(B * S, C)
        mod, _ = self.modulation(cond)

        def prepare_for_mod(y: Tensor) -> Tensor:
            """Split a global sequence per frame so each frame meets its own condition."""
            return y.reshape(B * S, P, C) if is_global else y

        def restore_after_mod(y: Tensor) -> Tensor:
            """Merge per-frame tokens back into the global sequence."""
            return y.reshape(B, S * P, C) if is_global else y

        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            """
            conditional attention following DiT implementation from Flux
            https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py#L194-L239
            """
            x = prepare_for_mod(x)
            x = (1 + mod.scale) * self.norm1(x) + mod.shift
            x = restore_after_mod(x)

            x = self.attn(x, pos=pos)

            x = prepare_for_mod(x)
            x = mod.gate * x
            x = restore_after_mod(x)
            return x

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            # NOTE(review): `mod` is computed for the full batch; if the helper
            # evaluates residual_func on a batch subset the modulation shapes
            # may not line up — confirm against drop_add_residual_stochastic_depth.
            x = drop_add_residual_stochastic_depth(
                x,
                pos=pos,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            # Each branch uses its own drop-path module (previously drop_path1
            # was reused here; both are built with the same rate).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
class Decoder(nn.Module):
    """Attention blocks after encoder per DPT input feature
    to generate point maps at a given time.

    The tokens of every selected intermediate layer are stacked along the
    batch dimension and pushed through the same alternating frame/global
    stack of conditional blocks.
    """

    def __init__(
        self,
        cfg,
        dim_in: int,
        intermediate_layer_idx: List[int] = [4, 11, 17, 23],
        patch_size=14,
        embed_dim=1024,
        depth=2,
        num_heads=16,
        mlp_ratio=4.0,
        block_fn=ConditionalBlock,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        aa_order=["frame", "global"],
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        """
        Args:
            cfg: stored on the module as ``self.cfg``.
            dim_in: stored on the module as ``self.dim_in``.
            intermediate_layer_idx: indices into ``aggregated_tokens_list``
                selecting which layers are decoded.
            patch_size: divisor applied to image height/width to get the
                position grid.
            embed_dim: blocks are built with ``dim = embed_dim * 2``.
            depth: number of frame blocks and of global blocks.
            num_heads, mlp_ratio, qkv_bias, proj_bias, ffn_bias, qk_norm,
            init_values: forwarded to ``block_fn``.
            block_fn: factory for a single block.
            aa_order: attention types applied in each round; each entry must
                be ``"frame"`` or ``"global"``.
            aa_block_size: blocks consumed per attention type per round.
            rope_freq: RoPE frequency; RoPE is disabled when ``<= 0``.

        Raises:
            ValueError: if ``depth`` is not divisible by ``aa_block_size``.
        """
        super().__init__()
        self.cfg = cfg
        # Copy the sequences so the shared default lists are never aliased.
        self.intermediate_layer_idx = list(intermediate_layer_idx)
        self.depth = depth
        self.aa_order = list(aa_order)
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
        self.aa_block_num = self.depth // self.aa_block_size

        self.rope = (
            RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        )
        self.position_getter = PositionGetter() if self.rope is not None else None

        self.dim_in = dim_in
        self.old_decoder = False

        def make_blocks(count: int) -> nn.ModuleList:
            """Build ``count`` identically configured blocks."""
            return nn.ModuleList(
                [
                    block_fn(
                        dim=embed_dim * 2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(count)
                ]
            )

        # Frame blocks are created before global blocks in both layouts so the
        # order of parameter initialisation is unchanged.
        if self.old_decoder:
            # Flat layout: blocks[idx]
            self.frame_blocks = make_blocks(depth)
            self.global_blocks = make_blocks(depth)
        else:
            # Nested layout: blocks[group][idx], with a single group.
            depths = [depth]
            self.frame_blocks = nn.ModuleList([make_blocks(d) for d in depths])
            self.global_blocks = nn.ModuleList([make_blocks(d) for d in depths])

        self.use_reentrant = False  # hardcoded to False

    def get_condition_tokens(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        cond_view_idxs: torch.Tensor
    ):
        """Gather the conditioning token of the requested view for every frame.

        Args:
            aggregated_tokens_list: per-layer tokens; only the last entry is
                used (``[B S N_tok D]`` per the original comment).
            cond_view_idxs: ``(b, s)`` indices along the view dimension.

        Returns:
            Tensor shaped like ``cond_view_idxs`` with two trailing dimensions
            ``(1, D)`` appended.
        """
        # Use tokens from the last block for conditioning
        tokens_last = aggregated_tokens_list[-1]  # [B S N_tok D]
        # Extract the camera tokens
        # NOTE(review): index 1 is used as the conditioning token — confirm it
        # matches the token layout produced by the aggregator.
        cond_token_idx = 1
        # List indexing keeps the token axis, so this is [B, S, 1, D].
        camera_tokens = tokens_last[:, :, [cond_token_idx]]

        cond_view_idxs = cond_view_idxs.to(camera_tokens.device)
        cond_view_idxs = repeat(
            cond_view_idxs,
            "b s -> b s c d",
            c=camera_tokens.shape[2],
            d=camera_tokens.shape[3],
        )
        cond_tokens = torch.gather(camera_tokens, 1, cond_view_idxs)
        return cond_tokens

    def forward(
        self,
        images: torch.Tensor,
        aggregated_tokens_list: List[torch.Tensor],
        patch_start_idx: int,
        cond_view_idxs: torch.Tensor,
    ):
        """Decode the selected intermediate layers with conditioned attention.

        Args:
            images: 5-D tensor unpacked as ``(B, S, _, H, W)``.
            aggregated_tokens_list: per-layer tokens, each ``(B, S, P, C)``.
            patch_start_idx: number of leading special tokens per frame; they
                receive position 0 when RoPE is enabled.
            cond_view_idxs: view indices passed to ``get_condition_tokens``.

        Returns:
            A list with one ``(B, S, P, C)`` tensor per entry of
            ``intermediate_layer_idx``, in the same order.
        """
        B, S, _, H, W = images.shape

        cond_tokens = self.get_condition_tokens(
            aggregated_tokens_list, cond_view_idxs
        )

        input_tokens = [
            aggregated_tokens_list[layer_idx].clone()
            for layer_idx in self.intermediate_layer_idx
        ]
        _, _, P, C = input_tokens[0].shape

        pos = None
        if self.rope is not None:
            pos = self.position_getter(
                B * S, H // self.patch_size, W // self.patch_size, device=images.device
            )
        if pos is not None and patch_start_idx > 0:
            # do not use position embedding for special tokens (camera and register tokens)
            # so set pos to 0 for the special tokens
            pos = pos + 1
            pos_special = torch.zeros(B * S, patch_start_idx, 2).to(images.device).to(pos.dtype)
            pos = torch.cat([pos_special, pos], dim=1)

        frame_idx = 0
        global_idx = 0
        N = len(input_tokens)

        # stack all intermediate layer tokens along batch dimension
        # they are all processed by the same decoder
        s_tokens = torch.cat(input_tokens)
        s_cond_tokens = torch.cat([cond_tokens] * N, dim=0)
        # pos stays None when RoPE is disabled.
        s_pos = torch.cat([pos] * N, dim=0) if pos is not None else None

        # perform time conditioned attention
        # Each call below consumes aa_block_size blocks, so the number of
        # rounds is aa_block_num (== depth when aa_block_size is 1).
        token_idx = 0
        for _ in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    s_tokens, frame_idx, _ = self._process_frame_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, frame_idx, pos=s_pos, token_idx=token_idx
                    )
                elif attn_type == "global":
                    s_tokens, global_idx, _ = self._process_global_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, global_idx, pos=s_pos, token_idx=token_idx
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

        # Normalise the layout first so the split is correct whether the last
        # attention type left tokens as (B*N*S, P, C) or (B*N, S*P, C).
        stacked = s_tokens.reshape(B * N, S, P, C)
        processed = list(stacked.split(B, dim=0))

        return processed

    def _select_block(self, blocks, block_idx, token_idx):
        """Return one block from either the flat (old) or nested layout."""
        if self.old_decoder:
            return blocks[block_idx]
        return blocks[token_idx][block_idx]

    def _process_frame_attention(self, tokens, cond_tokens, B, S, P, C, frame_idx, pos=None, token_idx=0):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).

        Returns the updated tokens, the next frame block index and the list of
        per-block outputs viewed as (B, S, P, C).
        """
        # If needed, reshape tokens or positions:
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            block = self._select_block(self.frame_blocks, frame_idx, token_idx)
            if self.training:
                # Positional args map to block.forward(x, pos, cond).
                tokens = checkpoint(block, tokens, pos, cond_tokens, use_reentrant=self.use_reentrant)
            else:
                tokens = block(tokens, pos=pos, cond=cond_tokens)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, cond_tokens, B, S, P, C, global_idx, pos=None, token_idx=0):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).

        Returns the updated tokens, the next global block index and the list
        of per-block outputs viewed as (B, S, P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            block = self._select_block(self.global_blocks, global_idx, token_idx)
            if self.training:
                # Positional args map to block.forward(x, pos, cond, is_global).
                tokens = checkpoint(block, tokens, pos, cond_tokens, True, use_reentrant=self.use_reentrant)
            else:
                tokens = block(tokens, pos=pos, cond=cond_tokens, is_global=True)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates