# Source file: vdpm/dpm/decoder.py
# Author: Edgar Sucar
# Last commit: da72871 ("licences")
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE-VGGT file in the root directory of this source tree.
import logging
import torch
from torch import nn, Tensor
from torch.utils.checkpoint import checkpoint
from typing import List, Callable
from dataclasses import dataclass
from einops import repeat
from vggt.layers.block import drop_add_residual_stochastic_depth
from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from vggt.layers.attention import Attention
from vggt.layers.drop_path import DropPath
from vggt.layers.layer_scale import LayerScale
from vggt.layers.mlp import Mlp
# Module-level logger for this file (not referenced by the code shown here).
logger = logging.getLogger(__name__)
@dataclass
class ModulationOut:
    """Container for one set of adaLN-style modulation tensors.

    Produced by ``Modulation.forward``. ``ConditionalBlock`` applies them as
    ``(1 + scale) * norm(x) + shift`` before attention and multiplies the
    attention output by ``gate``.
    """

    shift: Tensor  # additive offset applied after normalization
    scale: Tensor  # multiplicative factor, used as (1 + scale)
    gate: Tensor  # multiplies the branch output before the residual add
class Modulation(nn.Module):
    """Project a conditioning vector to shift/scale/gate modulation parameters.

    Args:
        dim: Width of the conditioning vector and of each produced tensor.
        double: If True, emit two ``ModulationOut`` triples (``6 * dim``
            linear outputs); otherwise emit one (``3 * dim`` outputs).
    """

    def __init__(self, dim: int, double: bool):
        super().__init__()
        n_params = 6 if double else 3
        self.multiplier = n_params
        self.is_double = double
        self.lin = nn.Linear(dim, n_params * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        """Return ``(first, second)``; ``second`` is None unless ``double``.

        Each tensor has a singleton axis inserted after the leading axis,
        so a ``(N, dim)`` input yields ``(N, 1, dim)`` fields.
        """
        activated = nn.functional.silu(vec)
        projected = self.lin(activated)[:, None, :]
        pieces = projected.chunk(self.multiplier, dim=-1)
        first = ModulationOut(*pieces[:3])
        second = None
        if self.is_double:
            second = ModulationOut(*pieces[3:])
        return first, second
class ConditionalBlock(nn.Module):
    """Transformer block whose attention branch is modulated by a conditioning token.

    The attention branch follows the DiT/Flux adaLN recipe. The input is
    normalized by an affine-free norm and modulated as
    ``(1 + scale) * norm1(x) + shift``. It is then passed through attention,
    and the result is multiplied by ``gate`` before the residual add. The
    shift, scale and gate come from :class:`Modulation` applied to ``cond``.

    The MLP branch is an ordinary pre-norm MLP with optional LayerScale
    (``ls2``). There is no LayerScale on the attention branch, which is
    gated instead, so ``init_values`` only affects the MLP branch.

    Args:
        dim: Token width ``C``; ``cond`` must also carry ``C`` features per frame.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP relative to ``dim``.
        qkv_bias, proj_bias, ffn_bias: Bias flags for attention and MLP layers.
        drop: Dropout for the attention projection and the MLP.
        attn_drop: Dropout on attention weights.
        init_values: LayerScale init for the MLP branch; falsy disables it.
        drop_path: Stochastic-depth rate for both residual branches.
        act_layer, norm_layer, attn_class, ffn_layer: Layer factories.
        qk_norm: Forwarded to ``attn_class``.
        fused_attn: Forwarded to ``attn_class``.
        rope: Rotary embedding module forwarded to ``attn_class``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()
        # Affine-free norm: the learned scale/shift are supplied by the modulation.
        self.norm1 = norm_layer(dim, elementwise_affine=False)
        self.modulation = Modulation(dim, double=False)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, cond=None, is_global=False) -> Tensor:
        """Apply the conditioned attention and MLP residual branches.

        Args:
            x: Tokens in frame layout ``(B*S, P, C)`` when ``is_global`` is
                False, or global layout ``(B, S*P, C)`` when it is True.
            pos: Positional ids forwarded to the attention layer (may be None).
            cond: Conditioning tensor whose first two dims are ``(B, S)`` and
                which holds exactly ``B * S * C`` elements (e.g. ``(B, S, 1, C)``).
                Each of the ``B * S`` frames is modulated by its own row.
            is_global: Selects which of the two layouts ``x`` is in.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``cond`` is None.
        """
        if cond is None:
            raise ValueError("ConditionalBlock.forward requires a conditioning tensor `cond`")

        B, S = cond.shape[:2]
        C = x.shape[-1]
        # Tokens per frame; only needed to regroup the global (B, S*P, C) layout.
        P = x.shape[1] // S if is_global else None
        cond = cond.view(B * S, C)
        mod, _ = self.modulation(cond)  # each field has shape (B*S, 1, C)

        def prepare_for_mod(y: Tensor) -> Tensor:
            """Regroup tokens per frame so each frame meets its own modulation row."""
            return y.view(B, S, P, C).view(B * S, P, C) if is_global else y

        def restore_after_mod(y: Tensor) -> Tensor:
            """Undo ``prepare_for_mod``, restoring the global sequence layout."""
            return y.view(B, S, P, C).view(B, S * P, C) if is_global else y

        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            """Conditional attention following the DiT implementation from Flux.

            https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py#L194-L239
            """
            x = prepare_for_mod(x)
            x = (1 + mod.scale) * self.norm1(x) + mod.shift
            x = restore_after_mod(x)
            x = self.attn(x, pos=pos)
            x = mod.gate * prepare_for_mod(x)
            return restore_after_mod(x)

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.0:
            # The batch-subset stochastic depth of the unconditioned block
            # (drop_add_residual_stochastic_depth) is deliberately not used.
            # It evaluates the residual on a random subset of the batch, but
            # `mod` and the (B, S, P) reshapes above are sized for the full
            # batch, so that path fails (or misaligns the conditioning).
            # Per-sample DropPath keeps every sample paired with its own modulation.
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
class Decoder(nn.Module):
    """Attention blocks applied after the encoder to each DPT input feature,
    to generate point maps at a given time.

    The tokens of every encoder layer listed in ``intermediate_layer_idx``
    are stacked along the batch axis and refined by one shared stack of
    conditioned blocks. Frame and global attention alternate in the order
    given by ``aa_order``. Every block is conditioned on tokens gathered from
    the last encoder layer at the views selected by ``cond_view_idxs``.
    """

    def __init__(
        self,
        cfg,
        dim_in: int,
        intermediate_layer_idx: List[int] = [4, 11, 17, 23],
        patch_size=14,
        embed_dim=1024,
        depth=2,
        num_heads=16,
        mlp_ratio=4.0,
        block_fn=ConditionalBlock,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        aa_order=["frame", "global"],
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        """Build the frame and global block stacks.

        Args:
            cfg: Stored as ``self.cfg``; not used elsewhere in this class.
            dim_in: Stored as ``self.dim_in``; not used elsewhere in this class.
            intermediate_layer_idx: Indices into ``aggregated_tokens_list``
                of the encoder layers to decode (copied, never mutated).
            patch_size: Patch size used to derive the RoPE grid from H and W.
            embed_dim: Blocks operate on width ``2 * embed_dim``.
            depth: Number of blocks per attention type.
            num_heads, mlp_ratio, qkv_bias, proj_bias, ffn_bias, qk_norm,
            init_values: Forwarded to every ``block_fn``.
            block_fn: Block factory, called with keyword arguments.
            aa_order: Attention types applied in each step; "frame" or "global".
            aa_block_size: Consecutive blocks of one type applied per step.
            rope_freq: RoPE frequency; ``<= 0`` disables RoPE.

        Raises:
            ValueError: If ``depth`` is not divisible by ``aa_block_size``.
        """
        super().__init__()
        self.cfg = cfg
        # Copy so instances never share (and mutate) the default lists.
        self.intermediate_layer_idx = list(intermediate_layer_idx)
        self.depth = depth
        self.aa_order = list(aa_order)
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
        self.aa_block_num = self.depth // self.aa_block_size

        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None
        self.dim_in = dim_in

        def make_blocks(count: int) -> nn.ModuleList:
            """Build ``count`` identically configured blocks of width ``2 * embed_dim``."""
            # NOTE(review): width 2 * embed_dim presumably matches encoder tokens
            # that concatenate two embed_dim-wide features; confirm with the encoder.
            return nn.ModuleList(
                [
                    block_fn(
                        dim=embed_dim * 2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(count)
                ]
            )

        # Legacy layout: flat lists (state-dict keys "frame_blocks.<i>").
        # Current layout: one nested list per entry of `depths`
        # (state-dict keys "frame_blocks.0.<i>").
        self.old_decoder = False
        if self.old_decoder:
            self.frame_blocks = make_blocks(depth)
            self.global_blocks = make_blocks(depth)
        else:
            depths = [depth]
            self.frame_blocks = nn.ModuleList([make_blocks(d) for d in depths])
            self.global_blocks = nn.ModuleList([make_blocks(d) for d in depths])

        self.use_reentrant = False  # hardcoded to False

    def _select_block(self, blocks: nn.ModuleList, token_idx: int, block_idx: int) -> nn.Module:
        """Return block ``block_idx`` from either the legacy flat or the nested layout."""
        if self.old_decoder:
            return blocks[block_idx]
        return blocks[token_idx][block_idx]

    def get_condition_tokens(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        cond_view_idxs: torch.Tensor,
    ):
        """Gather the conditioning token of the selected views from the last layer.

        Args:
            aggregated_tokens_list: Per-layer tokens, each ``[B, S, N_tok, D]``.
            cond_view_idxs: Integer tensor ``[B, S_c]`` of view indices along S.
                It is moved to the tokens' device and cast to int64, as
                ``torch.gather`` requires.

        Returns:
            Tensor ``[B, S_c, 1, D]``; entry ``[b, s]`` is the conditioning
            token of view ``cond_view_idxs[b, s]`` in batch element ``b``.
        """
        # Use tokens from the last block for conditioning
        tokens_last = aggregated_tokens_list[-1]  # [B, S, N_tok, D]

        # NOTE(review): token index 1 is used as the conditioning ("camera")
        # token; confirm against the encoder's special-token layout.
        cond_token_idx = 1
        camera_tokens = tokens_last[:, :, [cond_token_idx]]  # [B, S, 1, D]

        index = cond_view_idxs.to(device=camera_tokens.device, dtype=torch.long)
        index = repeat(
            index,
            "b s -> b s c d",
            c=camera_tokens.shape[2],
            d=camera_tokens.shape[3],
        )
        return torch.gather(camera_tokens, 1, index)

    def forward(
        self,
        images: torch.Tensor,
        aggregated_tokens_list: List[torch.Tensor],
        patch_start_idx: int,
        cond_view_idxs: torch.Tensor,
    ):
        """Refine the selected encoder layers with time-conditioned attention.

        Args:
            images: Tensor ``(B, S, _, H, W)``; only its shape and device are used.
            aggregated_tokens_list: Per-layer encoder tokens, each ``(B, S, P, C)``.
                P presumably equals ``patch_start_idx + (H // patch_size) *
                (W // patch_size)`` when RoPE is enabled; verify against the encoder.
            patch_start_idx: Number of leading special tokens per frame; they
                receive position 0 while patch positions are shifted by 1.
            cond_view_idxs: ``(B, S)`` view indices selecting the conditioning
                token for each frame (see ``get_condition_tokens``).

        Returns:
            List of tensors ``(B, S, P, C)``, one per entry of
            ``intermediate_layer_idx``, in the same order.
        """
        B, S, _, H, W = images.shape
        cond_tokens = self.get_condition_tokens(aggregated_tokens_list, cond_view_idxs)

        # torch.cat below already allocates a new tensor, so the encoder
        # outputs are never modified and no explicit clone is needed.
        input_tokens = [aggregated_tokens_list[layer_idx] for layer_idx in self.intermediate_layer_idx]
        _, _, P, C = input_tokens[0].shape

        pos = None
        if self.rope is not None:
            pos = self.position_getter(
                B * S, H // self.patch_size, W // self.patch_size, device=images.device
            )
            if patch_start_idx > 0:
                # do not use position embedding for special tokens (camera and register tokens)
                # so set pos to 0 for the special tokens
                pos = pos + 1
                pos_special = torch.zeros(
                    B * S, patch_start_idx, 2, device=images.device, dtype=pos.dtype
                )
                pos = torch.cat([pos_special, pos], dim=1)

        # Stack all intermediate layer tokens along the batch dimension
        # (layer-major), since they are all processed by the same decoder.
        # The conditioning and positions are tiled in the same order.
        N = len(input_tokens)
        s_tokens = torch.cat(input_tokens, dim=0)  # (N*B, S, P, C)
        s_cond_tokens = torch.cat([cond_tokens] * N, dim=0)
        s_pos = torch.cat([pos] * N, dim=0) if pos is not None else None

        # Perform time-conditioned attention. Each _process_* call consumes
        # aa_block_size blocks, so aa_block_num steps use every block exactly once.
        frame_idx = 0
        global_idx = 0
        token_idx = 0
        for _ in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    s_tokens, frame_idx, _ = self._process_frame_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, frame_idx, pos=s_pos, token_idx=token_idx
                    )
                elif attn_type == "global":
                    s_tokens, global_idx, _ = self._process_global_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, global_idx, pos=s_pos, token_idx=token_idx
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

        # The tokens may end in either the frame (N*B*S, P, C) or the global
        # (N*B, S*P, C) layout; normalize to (N*B, S, P, C) before splitting per layer.
        s_tokens = s_tokens.view(N * B, S, P, C)
        return list(s_tokens.split(B, dim=0))

    def _process_frame_attention(self, tokens, cond_tokens, B, S, P, C, frame_idx, pos=None, token_idx=0):
        """Run ``aa_block_size`` frame-attention blocks with tokens shaped ``(B*S, P, C)``.

        Blocks are checkpointed while training.

        Returns:
            ``(tokens, next_frame_idx, intermediates)``, where ``intermediates``
            holds each block's output viewed as ``(B, S, P, C)``.
        """
        # If needed, reshape tokens or positions:
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)
        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)

        intermediates = []
        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            block = self._select_block(self.frame_blocks, token_idx, frame_idx)
            if self.training:
                tokens = checkpoint(block, tokens, pos, cond_tokens, use_reentrant=self.use_reentrant)
            else:
                tokens = block(tokens, pos=pos, cond=cond_tokens)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))
        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, cond_tokens, B, S, P, C, global_idx, pos=None, token_idx=0):
        """Run ``aa_block_size`` global-attention blocks with tokens shaped ``(B, S*P, C)``.

        Blocks are checkpointed while training; ``is_global=True`` is passed positionally.

        Returns:
            ``(tokens, next_global_idx, intermediates)``, where ``intermediates``
            holds each block's output viewed as ``(B, S, P, C)``.
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)
        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)

        intermediates = []
        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            block = self._select_block(self.global_blocks, token_idx, global_idx)
            if self.training:
                tokens = checkpoint(block, tokens, pos, cond_tokens, True, use_reentrant=self.use_reentrant)
            else:
                tokens = block(tokens, pos=pos, cond=cond_tokens, is_global=True)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))
        return tokens, global_idx, intermediates