from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from common.cache import Cache
from common.distributed.ops import slice_inputs

from . import na
from .embedding import TimeEmbedding
from .modulation import get_ada_layer
from .nablocks import get_nablock
from .normalization import get_norm_layer
from .patch import get_na_patch_layers


def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs):
    # Run the module under activation checkpointing when enabled; otherwise call it directly.
    if enabled:
        return torch.utils.checkpoint.checkpoint(module, *args, use_reentrant=False, **kwargs)
    return module(*args, **kwargs)


@dataclass
class NaDiTOutput:
    vid_sample: torch.Tensor


class NaDiT(nn.Module):
    """
    Native Resolution Diffusion Transformer (NaDiT)
    """

    gradient_checkpointing = False

    def __init__(
        self,
        vid_in_channels: int,
        vid_out_channels: int,
        vid_dim: int,
        txt_in_dim: Union[int, List[int]],
        txt_dim: Optional[int],
        emb_dim: int,
        heads: int,
        head_dim: int,
        expand_ratio: int,
        norm: Optional[str],
        norm_eps: float,
        ada: str,
        qk_bias: bool,
        qk_norm: Optional[str],
        patch_size: Union[int, Tuple[int, int, int]],
        num_layers: int,
        block_type: Union[str, Tuple[str]],
        mm_layers: Union[int, Tuple[bool]],
        mlp_type: str = "normal",
        patch_type: str = "v1",
        rope_type: Optional[str] = "rope3d",
        rope_dim: Optional[int] = None,
        window: Optional[Tuple] = None,
        window_method: Optional[Tuple[str]] = None,
        msa_type: Optional[Tuple[str]] = None,
        mca_type: Optional[Tuple[str]] = None,
        txt_in_norm: Optional[str] = None,
        txt_in_norm_scale_factor: float = 0.01,
        txt_proj_type: Optional[str] = "linear",
        vid_out_norm: Optional[str] = None,
        **kwargs,
    ):
        ada = get_ada_layer(ada)
        norm = get_norm_layer(norm)
        qk_norm = get_norm_layer(qk_norm)
        rope_dim = rope_dim if rope_dim is not None else head_dim // 2
        if isinstance(block_type, str):
            block_type = [block_type] * num_layers
        elif len(block_type) != num_layers:
            raise ValueError("The length of ``block_type`` should equal ``num_layers``.")
        super().__init__()
        NaPatchIn, NaPatchOut = get_na_patch_layers(patch_type)
        self.vid_in = NaPatchIn(
            in_channels=vid_in_channels,
            patch_size=patch_size,
            dim=vid_dim,
        )
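        # Text projection: one projection (plus norm) per text encoder when ``txt_in_dim``
        # is a list, otherwise a single linear projection (or identity if dims already match).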
        if not isinstance(txt_in_dim, int):
            self.txt_in = nn.ModuleList([])
            for in_dim in txt_in_dim:
                txt_norm_layer = get_norm_layer(txt_in_norm)(txt_dim, norm_eps, True)
                if txt_proj_type == "linear":
                    txt_proj_layer = nn.Linear(in_dim, txt_dim)
                else:
                    txt_proj_layer = nn.Sequential(
                        nn.Linear(in_dim, in_dim), nn.GELU("tanh"), nn.Linear(in_dim, txt_dim)
                    )
                torch.nn.init.constant_(txt_norm_layer.weight, txt_in_norm_scale_factor)
                self.txt_in.append(
                    nn.Sequential(
                        txt_proj_layer,
                        txt_norm_layer,
                    )
                )
        else:
            self.txt_in = (
                nn.Linear(txt_in_dim, txt_dim)
                if txt_in_dim and txt_in_dim != txt_dim
                else nn.Identity()
            )
        self.emb_in = TimeEmbedding(
            sinusoidal_dim=256,
            hidden_dim=max(vid_dim, txt_dim),
            output_dim=emb_dim,
        )

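        # Broadcast per-layer attention settings when a single value is given for all layers.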
        if window is None or isinstance(window[0], int):
            window = [window] * num_layers
        if window_method is None or isinstance(window_method, str):
            window_method = [window_method] * num_layers

        if msa_type is None or isinstance(msa_type, str):
            msa_type = [msa_type] * num_layers
        if mca_type is None or isinstance(mca_type, str):
            mca_type = [mca_type] * num_layers

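        # Stack of native-resolution transformer blocks.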
        self.blocks = nn.ModuleList(
            [
                get_nablock(block_type[i])(
                    vid_dim=vid_dim,
                    txt_dim=txt_dim,
                    emb_dim=emb_dim,
                    heads=heads,
                    head_dim=head_dim,
                    expand_ratio=expand_ratio,
                    norm=norm,
                    norm_eps=norm_eps,
                    ada=ada,
                    qk_bias=qk_bias,
                    qk_norm=qk_norm,
                    shared_weights=not (
                        (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
                    ),
                    mlp_type=mlp_type,
                    window=window[i],
                    window_method=window_method[i],
                    msa_type=msa_type[i],
                    mca_type=mca_type[i],
                    rope_type=rope_type,
                    rope_dim=rope_dim,
                    is_last_layer=(i == num_layers - 1),
                    **kwargs,
                )
                for i in range(num_layers)
            ]
        )

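        # Optional output normalization and its adaptive modulation layer.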
        self.vid_out_norm = None
        if vid_out_norm is not None:
            self.vid_out_norm = get_norm_layer(vid_out_norm)(
                dim=vid_dim,
                eps=norm_eps,
                elementwise_affine=True,
            )
            self.vid_out_ada = ada(
                dim=vid_dim,
                emb_dim=emb_dim,
                layers=["out"],
                modes=["in"],
            )

        self.vid_out = NaPatchOut(
            out_channels=vid_out_channels,
            patch_size=patch_size,
            dim=vid_dim,
        )

    def set_gradient_checkpointing(self, enable: bool):
        self.gradient_checkpointing = enable

    def forward(
        self,
        vid: torch.FloatTensor,
        txt: Union[torch.FloatTensor, List[torch.FloatTensor]],
        vid_shape: torch.LongTensor,
        txt_shape: Union[torch.LongTensor, List[torch.LongTensor]],
        timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],
        disable_cache: bool = False,
    ):
        cache = Cache(disable=disable_cache)

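        # Text embedding: project each text stream, then pack into a single flattened sequence.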
        if isinstance(txt, list):
            assert isinstance(self.txt_in, nn.ModuleList)
            txt = [
                na.unflatten(fc(i), s) for fc, i, s in zip(self.txt_in, txt, txt_shape)
            ]
            txt, txt_shape = na.flatten([torch.cat(t, dim=0) for t in zip(*txt)])
            txt = slice_inputs(txt, dim=0)
        else:
            txt = slice_inputs(txt, dim=0)
            txt = self.txt_in(txt)

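        # Patchify the video latents into a flattened token sequence.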
        vid, vid_shape = self.vid_in(vid, vid_shape, cache)

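        # Timestep embedding shared by all blocks.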
        emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)

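        # Transformer blocks, with optional gradient checkpointing during training.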
        for block in self.blocks:
            vid, txt, vid_shape, txt_shape = gradient_checkpointing(
                enabled=(self.gradient_checkpointing and self.training),
                module=block,
                vid=vid,
                txt=txt,
                vid_shape=vid_shape,
                txt_shape=txt_shape,
                emb=emb,
                cache=cache,
            )

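        # Apply the optional output normalization and timestep-conditioned modulation.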
        if self.vid_out_norm:
            vid = self.vid_out_norm(vid)
            vid = self.vid_out_ada(
                vid,
                emb=emb,
                layer="out",
                mode="in",
                hid_len=cache("vid_len", lambda: vid_shape.prod(-1)),
                cache=cache,
                branch_tag="vid",
            )

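        # Unpatchify back to the output video channels.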
        vid, vid_shape = self.vid_out(vid, vid_shape, cache)
        return NaDiTOutput(vid_sample=vid)