| | |
| | |
| | |
| | |
| | |
| |
|
| | import logging |
| | import torch |
| | from torch import nn, Tensor |
| | from torch.utils.checkpoint import checkpoint |
| | from typing import List, Callable |
| | from dataclasses import dataclass |
| |
|
| | from einops import repeat |
| |
|
| | from vggt.layers.block import drop_add_residual_stochastic_depth |
| | from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter |
| |
|
| | from vggt.layers.attention import Attention |
| | from vggt.layers.drop_path import DropPath |
| | from vggt.layers.layer_scale import LayerScale |
| | from vggt.layers.mlp import Mlp |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
@dataclass
class ModulationOut:
    """adaLN-style modulation parameters (DiT/Flux) produced by `Modulation`.

    Consumed by `ConditionalBlock` around the attention sub-block as
    ``(1 + scale) * norm(x) + shift`` on the way in and ``gate * out`` on the
    way out of the residual branch.
    """

    # NOTE(review): each tensor presumably has a broadcast token axis,
    # i.e. (B*S, 1, C), coming from the `[:, None, :]` in Modulation.forward
    # — confirm against callers before relying on exact shapes.
    shift: Tensor
    scale: Tensor
    gate: Tensor
| |
|
| |
|
class Modulation(nn.Module):
    """Project a conditioning vector into adaLN shift/scale/gate parameters.

    A single linear layer (after SiLU) emits ``3 * dim`` features — or
    ``6 * dim`` when ``double`` is True — which are split into one
    `ModulationOut` triple (and optionally a second one).
    """

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        if double:
            self.multiplier = 6
        else:
            self.multiplier = 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        # Activate, project, insert a broadcast token axis, then split the
        # feature dimension into `multiplier` chunks of size `dim`.
        activated = nn.functional.silu(vec)
        projected = self.lin(activated)[:, None, :]
        chunks = projected.chunk(self.multiplier, dim=-1)

        primary = ModulationOut(*chunks[:3])
        secondary = ModulationOut(*chunks[3:]) if self.is_double else None
        return (primary, secondary)
| |
|
| |
|
class ConditionalBlock(nn.Module):
    """Transformer block with adaLN conditioning on the attention branch.

    The attention sub-block is modulated by shift/scale/gate parameters
    derived from a conditioning token via `Modulation` (DiT/Flux style); the
    MLP sub-block is a plain pre-norm residual branch with optional
    LayerScale. Supports stochastic depth during training.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,
        rope=None,
    ) -> None:
        super().__init__()

        # elementwise_affine=False: the scale/shift for norm1 come from the
        # conditioning modulation, not from learned LayerNorm parameters.
        self.norm1 = norm_layer(dim, elementwise_affine=False)
        self.modulation = Modulation(dim, double=False)

        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        # NOTE: truthiness check — init_values=0 (or 0.0) disables LayerScale
        # just like None does.
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        # NOTE(review): drop_path2 is constructed but the per-sample branch
        # below (forward, `> 0.0` case) applies drop_path1 to the FFN branch
        # as well; both carry the same rate, so behavior is statistically
        # identical — flagging for consistency.
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Rate reused to pick the stochastic-depth strategy in forward().
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, cond=None, is_global=False) -> Tensor:
        """Apply conditional attention + FFN residual branches.

        Args:
            x: tokens — (B*S, P, C) for frame attention, (B, S*P, C) when
               ``is_global`` (inferred from the reshapes below).
            pos: optional rotary position indices, forwarded to attention.
            cond: conditioning tokens; first two dims are (B, S).
                  NOTE(review): only reshaped to (B*S, C) when ``is_global``
                  — presumably the frame-attention caller already provides a
                  modulation-compatible shape; confirm at the call site.
            is_global: whether tokens hold the whole sequence (S*P) per batch.
        """
        B, S = cond.shape[:2]
        C = x.shape[-1]
        if is_global:
            P = x.shape[1] // S
            cond = cond.view(B * S, C)
        mod, _ = self.modulation(cond)

        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            """
            conditional attention following DiT implementation from Flux
            https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py#L194-L239
            """
            # For global attention, modulation must be applied per-view:
            # temporarily fold S into the batch axis, modulate, then unfold.
            def prepare_for_mod(y):
                """reshape to modulate the patch tokens with correct conditioning one"""
                return y.view(B, S, P, C).view(B * S, P, C) if is_global else y
            def restore_after_mod(y):
                """reshape back to global sequence"""
                return y.view(B, S, P, C).view(B, S * P, C) if is_global else y

            # adaLN in: (1 + scale) * norm(x) + shift
            x = prepare_for_mod(x)
            x = (1 + mod.scale) * self.norm1(x) + mod.shift
            x = restore_after_mod(x)

            # Attention runs in the caller's layout (global or frame).
            x = self.attn(x, pos=pos)

            # adaLN out: gate the residual contribution.
            x = prepare_for_mod(x)
            x = mod.gate * x
            x = restore_after_mod(x)

            return x

        def ffn_residual_func(x: Tensor) -> Tensor:
            # Plain pre-norm MLP branch with optional LayerScale.
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # High drop rate: use the batched stochastic-depth helper, which
            # drops whole samples and rescales the residual (DINOv2-style).
            x = drop_add_residual_stochastic_depth(
                x,
                pos=pos,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            # Low drop rate: cheaper per-sample DropPath.
            # NOTE(review): drop_path1 is used for BOTH branches here;
            # drop_path2 is never applied (same rate, so statistically
            # equivalent) — see __init__.
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path1(ffn_residual_func(x))
        else:
            # Eval / no stochastic depth: plain residual adds.
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
| |
|
| |
|
class Decoder(nn.Module):
    """Attention blocks after encoder per DPT input feature
    to generate point maps at a given time.

    Alternates frame-local and global conditional attention (order given by
    ``aa_order``) over the intermediate encoder features selected by
    ``intermediate_layer_idx``, conditioning every block on per-view camera
    tokens gathered from the last aggregated layer.
    """

    def __init__(
        self,
        cfg,
        dim_in: int,
        intermediate_layer_idx: List[int] | None = None,
        patch_size=14,
        embed_dim=1024,
        depth=2,
        num_heads=16,
        mlp_ratio=4.0,
        block_fn=ConditionalBlock,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        aa_order=("frame", "global"),
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        """
        Args:
            cfg: project config object (stored as-is, not interpreted here).
            dim_in: channel dimension of the incoming aggregated tokens.
            intermediate_layer_idx: encoder layer indices to decode; defaults
                to [4, 11, 17, 23]. (None sentinel avoids the shared
                mutable-default-argument pitfall.)
            patch_size: image patch size used to derive the token grid.
            embed_dim: base embedding width; blocks run at ``embed_dim * 2``.
            depth: number of frame/global block pairs.
            block_fn: block class to instantiate (defaults to ConditionalBlock).
            aa_order: order of alternating attention, e.g. ("frame", "global").
            aa_block_size: consecutive blocks per attention type.
            rope_freq: rotary embedding base frequency; <= 0 disables RoPE.
            init_values: LayerScale init for the blocks' FFN branch.

        Raises:
            ValueError: if ``depth`` is not divisible by ``aa_block_size``.
        """
        super().__init__()
        self.cfg = cfg
        if intermediate_layer_idx is None:
            intermediate_layer_idx = [4, 11, 17, 23]
        self.intermediate_layer_idx = list(intermediate_layer_idx)

        self.depth = depth
        self.aa_order = list(aa_order)
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.aa_block_num = self.depth // self.aa_block_size

        # 2D rotary position embedding over the patch grid (optional).
        self.rope = (
            RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        )
        self.position_getter = PositionGetter() if self.rope is not None else None

        self.dim_in = dim_in

        def make_block() -> nn.Module:
            # Every frame/global block shares this configuration.
            # NOTE(review): dim=embed_dim*2 presumably matches the channel
            # width of the aggregated tokens — confirm against the encoder.
            return block_fn(
                dim=embed_dim * 2,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                init_values=init_values,
                qk_norm=qk_norm,
                rope=self.rope,
            )

        # Legacy flat layout vs current nested layout (one inner ModuleList
        # per entry of `depths`). old_decoder is hard-coded off; the flag is
        # kept for checkpoint compatibility.
        self.old_decoder = False
        if self.old_decoder:
            self.frame_blocks = nn.ModuleList([make_block() for _ in range(depth)])
            self.global_blocks = nn.ModuleList([make_block() for _ in range(depth)])
        else:
            depths = [depth]
            self.frame_blocks = nn.ModuleList(
                [nn.ModuleList([make_block() for _ in range(d)]) for d in depths]
            )
            self.global_blocks = nn.ModuleList(
                [nn.ModuleList([make_block() for _ in range(d)]) for d in depths]
            )

        # Non-reentrant gradient checkpointing (the recommended mode).
        self.use_reentrant = False

    def get_condition_tokens(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        cond_view_idxs: torch.Tensor
    ):
        """Gather the camera token of each view's conditioning view.

        Args:
            aggregated_tokens_list: per-layer tokens, each (B, S, P, C); only
                the last layer is read.
            cond_view_idxs: (B, S) long tensor — for every target view, the
                index of the view it is conditioned on.

        Returns:
            (B, S, 1, C) tensor of selected camera tokens.
        """
        tokens_last = aggregated_tokens_list[-1]
        # Token index 1 is the camera token. NOTE(review): token 0 is
        # presumably a register/class token — confirm against the aggregator.
        cond_token_idx = 1
        camera_tokens = tokens_last[:, :, [cond_token_idx]]

        # Broadcast the (B, S) indices up to the token shape so torch.gather
        # can select along the view dimension.
        cond_view_idxs = cond_view_idxs.to(camera_tokens.device)
        cond_view_idxs = repeat(
            cond_view_idxs,
            "b s -> b s c d",
            c=camera_tokens.shape[2],
            d=camera_tokens.shape[3],
        )
        cond_tokens = torch.gather(camera_tokens, 1, cond_view_idxs)

        return cond_tokens

    def forward(
        self,
        images: torch.Tensor,
        aggregated_tokens_list: List[torch.Tensor],
        patch_start_idx: int,
        cond_view_idxs: torch.Tensor,
    ):
        """Decode the selected intermediate encoder features.

        Args:
            images: (B, S, 3, H, W) frames — used only for sizes and device.
            aggregated_tokens_list: encoder tokens per layer, each (B, S, P, C).
            patch_start_idx: number of special (non-patch) tokens preceding
                the patch tokens in the token sequence.
            cond_view_idxs: (B, S) conditioning-view index per target view.

        Returns:
            List of N tensors of shape (B, S, P, C), one per entry of
            ``intermediate_layer_idx``.
        """
        B, S, _, H, W = images.shape

        cond_tokens = self.get_condition_tokens(
            aggregated_tokens_list, cond_view_idxs
        )

        # Clone so decoding never mutates the encoder outputs in place.
        input_tokens = [
            aggregated_tokens_list[layer_idx].clone()
            for layer_idx in self.intermediate_layer_idx
        ]

        _, _, P, C = input_tokens[0].shape

        pos = None
        if self.rope is not None:
            pos = self.position_getter(
                B * S, H // self.patch_size, W // self.patch_size, device=images.device
            )
            if patch_start_idx > 0:
                # Shift patch positions to start at 1 and assign position 0
                # to all special tokens (they carry no spatial location).
                pos = pos + 1
                pos_special = torch.zeros(
                    B * S, patch_start_idx, 2, device=images.device, dtype=pos.dtype
                )
                pos = torch.cat([pos_special, pos], dim=1)

        frame_idx = 0
        global_idx = 0
        depth = len(self.frame_blocks[0])
        N = len(input_tokens)

        # Stack the N feature levels along the batch axis so one pass of the
        # shared blocks processes all levels at once.
        s_tokens = torch.cat(input_tokens)
        s_cond_tokens = torch.cat([cond_tokens] * N, dim=0)
        # Guard the RoPE-disabled case (pos is None) instead of crashing.
        s_pos = torch.cat([pos] * N, dim=0) if pos is not None else None

        for _ in range(depth):
            for attn_type in self.aa_order:
                # `depths` has a single entry, so the nested ModuleLists are
                # always indexed at 0.
                token_idx = 0

                if attn_type == "frame":
                    s_tokens, frame_idx, _ = self._process_frame_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, frame_idx, pos=s_pos, token_idx=token_idx
                    )
                elif attn_type == "global":
                    s_tokens, global_idx, _ = self._process_global_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, global_idx, pos=s_pos, token_idx=token_idx
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

        # Split the stacked level axis back out: N tensors of (B, S, P, C).
        processed = [t.view(B, S, P, C) for t in s_tokens.split(B, dim=0)]

        return processed

    def _select_block(self, blocks, token_idx, layer_idx):
        """Index into `blocks` under either the flat (legacy) or nested layout.

        Fixes a latent crash: the original training path always assumed the
        nested layout, which would fail if `old_decoder` were ever enabled.
        """
        if self.old_decoder:
            return blocks[layer_idx]
        return blocks[token_idx][layer_idx]

    def _process_frame_attention(self, tokens, cond_tokens, B, S, P, C, frame_idx, pos=None, token_idx=0):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)

        intermediates = []

        for _ in range(self.aa_block_size):
            block = self._select_block(self.frame_blocks, token_idx, frame_idx)
            if self.training:
                # Gradient checkpointing: trade recompute for memory.
                tokens = checkpoint(block, tokens, pos, cond_tokens, use_reentrant=self.use_reentrant)
            else:
                tokens = block(tokens, pos=pos, cond=cond_tokens)

            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, cond_tokens, B, S, P, C, global_idx, pos=None, token_idx=0):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)

        intermediates = []

        for _ in range(self.aa_block_size):
            block = self._select_block(self.global_blocks, token_idx, global_idx)
            if self.training:
                # Positional `True` is the block's `is_global` flag (keyword
                # args are not supported through checkpoint's *args path).
                tokens = checkpoint(block, tokens, pos, cond_tokens, True, use_reentrant=self.use_reentrant)
            else:
                tokens = block(tokens, pos=pos, cond=cond_tokens, is_global=True)

            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates
| |
|