| |
| |
| |
| |
| |
|
|
| import logging |
| import torch |
| from torch import nn, Tensor |
| from torch.utils.checkpoint import checkpoint |
| from typing import List, Callable |
| from dataclasses import dataclass |
|
|
| from einops import repeat |
|
|
| from vggt.layers.block import drop_add_residual_stochastic_depth |
| from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter |
|
|
| from vggt.layers.attention import Attention |
| from vggt.layers.drop_path import DropPath |
| from vggt.layers.layer_scale import LayerScale |
| from vggt.layers.mlp import Mlp |
|
|
# Module-level logger namespaced by this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class ModulationOut:
    """One set of modulation tensors produced by ``Modulation``.

    ``ConditionalBlock`` applies them as ``(1 + scale) * norm(x) + shift``
    before attention and as ``gate * attn_out`` after it.
    """

    shift: Tensor  # additive offset applied after normalisation
    scale: Tensor  # multiplicative term, used as (1 + scale)
    gate: Tensor  # multiplier applied to the attention branch output
|
|
|
|
class Modulation(nn.Module):
    """Map a conditioning vector to shift/scale/gate modulation tensors.

    Mirrors the Flux/DiT modulation layer: the conditioning vector goes through
    SiLU and a single linear layer whose output is split into 3 chunks
    (single) or 6 chunks (double) of width ``dim``.
    """

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        # Each ModulationOut needs three chunks; "double" produces two of them.
        self.multiplier = 3 * (2 if double else 1)
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        """Compute the modulation tensors for ``vec``.

        Args:
            vec: Conditioning tensor whose last dimension is ``dim``. For a
                2-D ``(N, dim)`` input every returned tensor is ``(N, 1, dim)``.

        Returns:
            ``(first, second)`` where ``second`` is ``None`` unless the layer
            was built with ``double=True``.
        """
        projected = self.lin(nn.functional.silu(vec))
        # Insert a singleton axis at position 1 so the chunks broadcast over
        # the second axis of the tensor being modulated.
        chunks = projected[:, None, :].chunk(self.multiplier, dim=-1)

        first = ModulationOut(shift=chunks[0], scale=chunks[1], gate=chunks[2])
        if not self.is_double:
            return first, None
        second = ModulationOut(shift=chunks[3], scale=chunks[4], gate=chunks[5])
        return first, second
|
|
|
|
class ConditionalBlock(nn.Module):
    """Transformer block whose attention branch is modulated by a condition.

    Attention branch (Flux/DiT style):
        ``x + gate * attn((1 + scale) * norm1(x) + shift)``
    where ``shift``, ``scale`` and ``gate`` come from ``Modulation(cond)``.
    MLP branch (standard pre-norm residual):
        ``x + ls2(mlp(norm2(x)))``
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,
        rope=None,
    ) -> None:
        super().__init__()

        # --- attention branch ---
        # norm1 has no affine parameters: the per-sample shift/scale from
        # self.modulation are applied right after it instead.
        self.norm1 = norm_layer(dim, elementwise_affine=False)
        self.modulation = Modulation(dim, double=False)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # --- MLP branch ---
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, cond=None, is_global=False) -> Tensor:
        """Apply the conditioned attention branch and the MLP branch.

        Args:
            x: Tokens. With ``is_global=False`` they are laid out as
                ``(B*S, P, C)``. With ``is_global=True`` they are ``(B, S*P, C)``
                and are temporarily viewed as ``(B*S, P, C)``, so each frame's
                tokens receive that frame's modulation.
            pos: Optional positions forwarded unchanged to the attention layer.
            cond: Conditioning tensor whose first two dims are ``(B, S)`` and
                which holds exactly ``B*S*C`` elements (e.g. ``(B, S, 1, C)``).
            is_global: Whether ``x`` is in the global ``(B, S*P, C)`` layout.

        Returns:
            Tensor in the same layout as ``x``.

        Raises:
            ValueError: if ``cond`` is None.
        """
        if cond is None:
            raise ValueError("ConditionalBlock requires a conditioning tensor `cond`.")

        B, S = cond.shape[:2]
        C = x.shape[-1]
        P = x.shape[1] // S if is_global else None
        # One modulation vector per (batch, frame) pair -> chunks of shape (B*S, 1, C).
        # reshape (not view) so a non-contiguous cond is also accepted.
        mod, _ = self.modulation(cond.reshape(B * S, C))

        def to_frames(y: Tensor) -> Tensor:
            """Global (B, S*P, C) -> per-frame (B*S, P, C); identity for frame layout."""
            return y.view(B * S, P, C) if is_global else y

        def to_sequence(y: Tensor) -> Tensor:
            """Inverse of ``to_frames``: (B*S, P, C) -> (B, S*P, C)."""
            return y.view(B, S * P, C) if is_global else y

        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            """
            conditional attention following DiT implementation from Flux
            https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py#L194-L239
            """
            h = to_frames(x)
            h = (1 + mod.scale) * self.norm1(h) + mod.shift
            h = self.attn(to_sequence(h), pos=pos)
            return to_sequence(mod.gate * to_frames(h))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.0:
            # Per-sample DropPath is used for every positive drop ratio.
            # NOTE(review): the batch-subset helper `drop_add_residual_stochastic_depth`
            # (DINOv2-style) evaluates the residual on a random subset of batch rows,
            # but `mod` is computed for the full B*S batch. Modulating such a subset
            # cannot broadcast / reshape, so that path is not usable here.
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
|
|
|
|
class Decoder(nn.Module):
    """Attention blocks after encoder per DPT input feature
    to generate point maps at a given time.

    The selected aggregator layers are stacked along the batch axis and
    processed together by groups of frame / global ``ConditionalBlock``s. Each
    view is modulated by a special token taken from its conditioning view.
    """

    def __init__(
        self,
        cfg,
        dim_in: int,
        intermediate_layer_idx: List[int] | None = None,
        patch_size=14,
        embed_dim=1024,
        depth=2,
        num_heads=16,
        mlp_ratio=4.0,
        block_fn=ConditionalBlock,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        aa_order: List[str] | None = None,
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        """
        Args:
            cfg: Configuration object, stored as ``self.cfg``.
            dim_in: Input feature dimension, stored as ``self.dim_in``.
            intermediate_layer_idx: Indices into ``aggregated_tokens_list`` of
                the layers to decode. Defaults to ``[4, 11, 17, 23]``.
            patch_size: Patch size used to derive the patch grid from H and W.
            embed_dim: Blocks are built with width ``2 * embed_dim``.
            depth: Number of frame blocks and number of global blocks.
            aa_order: Attention types run per group, each "frame" or "global".
                Defaults to ``["frame", "global"]``.
            aa_block_size: Consecutive blocks of one type run per group; must
                be positive and divide ``depth``.
            rope_freq: RoPE frequency; ``<= 0`` disables rotary embeddings.
            Remaining arguments are forwarded to every ``block_fn`` call.

        Raises:
            ValueError: if ``aa_block_size`` is not positive or does not divide ``depth``.
        """
        super().__init__()
        self.cfg = cfg
        # Copy so instances never share (or mutate) a caller's / default list.
        self.intermediate_layer_idx = (
            [4, 11, 17, 23] if intermediate_layer_idx is None else list(intermediate_layer_idx)
        )

        self.depth = depth
        self.aa_order = ["frame", "global"] if aa_order is None else list(aa_order)
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        if self.aa_block_size <= 0:
            raise ValueError(f"aa_block_size must be positive, got {aa_block_size}")
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        # Number of frame/global groups; each group runs aa_block_size blocks per type.
        self.aa_block_num = self.depth // self.aa_block_size

        self.rope = (
            RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        )
        self.position_getter = PositionGetter() if self.rope is not None else None

        self.dim_in = dim_in

        def make_block() -> nn.Module:
            """Build one block with the shared hyper-parameters."""
            return block_fn(
                dim=embed_dim * 2,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                init_values=init_values,
                qk_norm=qk_norm,
                rope=self.rope,
            )

        # Frame blocks are created before global blocks (same order as before),
        # keeping parameter initialisation and state_dict keys unchanged.
        self.old_decoder = False
        if self.old_decoder:
            # Legacy flat layout: frame_blocks[i] / global_blocks[i].
            self.frame_blocks = nn.ModuleList([make_block() for _ in range(depth)])
            self.global_blocks = nn.ModuleList([make_block() for _ in range(depth)])
        else:
            # Nested layout: frame_blocks[0][i] / global_blocks[0][i].
            depths = [depth]
            self.frame_blocks = nn.ModuleList(
                [nn.ModuleList([make_block() for _ in range(d)]) for d in depths]
            )
            self.global_blocks = nn.ModuleList(
                [nn.ModuleList([make_block() for _ in range(d)]) for d in depths]
            )

        self.use_reentrant = False

    def get_condition_tokens(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        cond_view_idxs: torch.Tensor
    ):
        """Gather, for every view, a special token of its conditioning view.

        Args:
            aggregated_tokens_list: Per-layer tokens; the last entry is indexed
                as ``(B, S, P, C)``.
            cond_view_idxs: Integer tensor of shape ``(B, S)``. Entry ``[b, s]``
                is the view whose token conditions view ``s`` of sample ``b``.

        Returns:
            Tensor of shape ``(B, S, 1, C)``.
        """
        tokens_last = aggregated_tokens_list[-1]

        # NOTE(review): token index 1 is used as the condition token; confirm it
        # matches the aggregator's special-token layout.
        cond_token_idx = 1
        camera_tokens = tokens_last[:, :, [cond_token_idx]]  # (B, S, 1, C)

        # torch.gather requires an int64 index on the same device as the source.
        index = cond_view_idxs.to(device=camera_tokens.device, dtype=torch.long)
        index = repeat(
            index,
            "b s -> b s c d",
            c=camera_tokens.shape[2],
            d=camera_tokens.shape[3],
        )
        return torch.gather(camera_tokens, 1, index)

    def forward(
        self,
        images: torch.Tensor,
        aggregated_tokens_list: List[torch.Tensor],
        patch_start_idx: int,
        cond_view_idxs: torch.Tensor,
    ):
        """Decode the selected aggregator layers with conditioned attention.

        Args:
            images: Tensor of shape ``(B, S, _, H, W)``; only its shape and
                device are used.
            aggregated_tokens_list: Per-layer tokens; each selected layer is
                ``(B, S, P, C)``, with C expected to equal the block width
                ``2 * embed_dim``.
            patch_start_idx: Number of special tokens before the patch tokens
                of each frame.
            cond_view_idxs: ``(B, S)`` integer tensor, see ``get_condition_tokens``.

        Returns:
            List with one ``(B, S, P, C)`` tensor per entry of
            ``intermediate_layer_idx``, in the same order.

        Raises:
            ValueError: if no layers are selected or ``aa_order`` contains an
                unknown attention type.
        """
        B, S, _, H, W = images.shape

        cond_tokens = self.get_condition_tokens(aggregated_tokens_list, cond_view_idxs)

        # torch.cat below already copies, so no explicit clone is needed.
        input_tokens = [aggregated_tokens_list[i] for i in self.intermediate_layer_idx]
        if not input_tokens:
            raise ValueError("intermediate_layer_idx must select at least one layer")
        _, _, P, C = input_tokens[0].shape

        pos = None
        if self.rope is not None:
            pos = self.position_getter(
                B * S, H // self.patch_size, W // self.patch_size, device=images.device
            )
            if patch_start_idx > 0:
                # Shift patch positions by +1 and give the special tokens position 0.
                pos = pos + 1
                pos_special = torch.zeros(
                    B * S, patch_start_idx, 2, device=images.device, dtype=pos.dtype
                )
                pos = torch.cat([pos_special, pos], dim=1)

        # Stack the N layers along the batch axis (layer-major) so they share one
        # pass through the blocks; conditions and positions are tiled to match.
        N = len(input_tokens)
        s_tokens = torch.cat(input_tokens, dim=0)
        s_cond_tokens = torch.cat([cond_tokens] * N, dim=0)
        s_pos = torch.cat([pos] * N, dim=0) if pos is not None else None

        frame_idx = 0
        global_idx = 0
        # Each helper call consumes aa_block_size blocks, so iterate over groups.
        for _ in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    s_tokens, frame_idx, _ = self._process_frame_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, frame_idx, pos=s_pos, token_idx=0
                    )
                elif attn_type == "global":
                    s_tokens, global_idx, _ = self._process_global_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, global_idx, pos=s_pos, token_idx=0
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

        # Tokens end in either the frame (B*N*S, P, C) or the global (B*N, S*P, C)
        # layout. Both share one memory order, so normalise before splitting per layer.
        s_tokens = s_tokens.view(B * N, S, P, C)
        return list(s_tokens.split(B, dim=0))

    def _select_block(self, blocks: nn.ModuleList, token_idx: int, block_idx: int) -> nn.Module:
        """Return block ``block_idx`` from the flat (old_decoder) or nested layout."""
        if self.old_decoder:
            return blocks[block_idx]
        return blocks[token_idx][block_idx]

    def _process_frame_attention(self, tokens, cond_tokens, B, S, P, C, frame_idx, pos=None, token_idx=0):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).

        Runs ``aa_block_size`` consecutive frame blocks starting at ``frame_idx``.

        Returns:
            ``(tokens, next_frame_idx, intermediates)`` where ``intermediates``
            holds each block's output viewed as ``(B, S, P, C)``.
        """
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B * S, P, C)
        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B * S, P, 2)

        intermediates = []
        for _ in range(self.aa_block_size):
            block = self._select_block(self.frame_blocks, token_idx, frame_idx)
            if self.training:
                # Activation checkpointing: recompute in backward to save memory.
                tokens = checkpoint(block, tokens, pos, cond_tokens, use_reentrant=self.use_reentrant)
            else:
                tokens = block(tokens, pos=pos, cond=cond_tokens)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, cond_tokens, B, S, P, C, global_idx, pos=None, token_idx=0):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).

        Runs ``aa_block_size`` consecutive global blocks starting at ``global_idx``.

        Returns:
            ``(tokens, next_global_idx, intermediates)`` where ``intermediates``
            holds each block's output viewed as ``(B, S, P, C)``.
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S * P, C)
        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S * P, 2)

        intermediates = []
        for _ in range(self.aa_block_size):
            block = self._select_block(self.global_blocks, token_idx, global_idx)
            if self.training:
                # Positional args map to (x, pos, cond, is_global=True).
                tokens = checkpoint(block, tokens, pos, cond_tokens, True, use_reentrant=self.use_reentrant)
            else:
                tokens = block(tokens, pos=pos, cond=cond_tokens, is_global=True)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates
|
|