| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Tuple, Union |
| import torch |
| from einops import rearrange |
| from torch.nn import functional as F |
|
|
| |
| from common.cache import Cache |
| from common.distributed.ops import gather_heads_scatter_seq, gather_seq_scatter_heads_qkv |
|
|
| from .. import na |
| from ..attention import FlashAttentionVarlen |
| from ..blocks.mmdit_window_block import MMWindowAttention, MMWindowTransformerBlock |
| from ..mm import MMArg |
| from ..modulation import ada_layer_type |
| from ..normalization import norm_layer_type |
| from ..rope import NaRotaryEmbedding3d |
| from ..window import get_window_op |
|
|
|
|
class NaSwinAttention(MMWindowAttention):
    """Windowed video/text attention over packed, variable-length samples.

    Video tokens are cut into 3-D windows (the slicing itself is delegated to
    the op returned by ``get_window_op(window_method)``). Each window forms one
    segment of a packed variable-length batch together with text tokens, and
    the whole batch goes through ``FlashAttentionVarlen``.

    Compared with ``MMWindowAttention``, this class replaces:
    - the rotary embedding, with ``NaRotaryEmbedding3d``;
    - the attention kernel;
    - the window op.
    """

    def __init__(
        self,
        vid_dim: int,
        txt_dim: int,
        heads: int,
        head_dim: int,
        qk_bias: bool,
        qk_rope: bool,
        qk_norm: norm_layer_type,
        qk_norm_eps: float,
        window: Union[int, Tuple[int, int, int]],
        window_method: str,
        shared_qkv: bool,
        **kwargs,
    ):
        """Build the parent layers, then swap in the varlen-specific parts.

        Any extra ``**kwargs`` are accepted but not forwarded to the parent.
        """
        super().__init__(
            vid_dim=vid_dim,
            txt_dim=txt_dim,
            heads=heads,
            head_dim=head_dim,
            qk_bias=qk_bias,
            qk_rope=qk_rope,
            qk_norm=qk_norm,
            qk_norm_eps=qk_norm_eps,
            window=window,
            window_method=window_method,
            shared_qkv=shared_qkv,
        )
        # Override the parent's rope / attention kernel / window op.
        # The assignment order (rope, attn, window_op) is intentional and kept fixed.
        if qk_rope:
            self.rope = NaRotaryEmbedding3d(dim=head_dim // 2)
        else:
            self.rope = None
        self.attn = FlashAttentionVarlen()
        self.window_op = get_window_op(window_method)

    def forward(
        self,
        vid: torch.FloatTensor,
        txt: torch.FloatTensor,
        vid_shape: torch.LongTensor,
        txt_shape: torch.LongTensor,
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Apply windowed joint attention to video and text tokens.

        Args:
            vid: video tokens, passed to ``proj_qkv``.
            txt: text tokens, passed to ``proj_qkv``.
            vid_shape: per-sample video shapes. These are used as ``qkv_shape``
                for the redistribution and to build the window layout.
                NOTE(review): presumably one ``(t, h, w)`` row per sample;
                confirm against ``na.window_idx``.
            txt_shape: per-sample text shapes. ``txt_shape.prod(-1)`` gives
                each sample's text length.
            cache: keyed store called as ``cache(key, fn)`` and
                ``cache.namespace(name)``. It holds derived index/length
                tensors. Presumably ``fn`` only runs on a miss.

        Returns:
            ``(vid_out, txt_out)``, both produced by ``proj_out``.
        """
        # Fused q/k/v projection for both modalities.
        vid_qkv, txt_qkv = self.proj_qkv(vid, txt)

        # NOTE(review): by its name this gathers the sequence and scatters heads
        # across ranks (sequence parallelism). If it is a collective, every rank
        # must issue the vid call before the txt call, so keep this order.
        vid_qkv = gather_seq_scatter_heads_qkv(
            vid_qkv, seq_dim=0, qkv_shape=vid_shape, cache=cache.namespace("vid")
        )
        txt_qkv = gather_seq_scatter_heads_qkv(
            txt_qkv, seq_dim=0, qkv_shape=txt_shape, cache=cache.namespace("txt")
        )

        # Namespace for every window-derived tensor, keyed by method and window size.
        win_cache = cache.namespace(f"{self.window_method}_{self.window}_sd3")

        def slice_into_windows(grid: torch.Tensor):
            # Receives a 4-D tensor (presumably one sample laid out as t, h, w
            # plus a trailing feature dim). Returns the crops chosen by
            # self.window_op for this (t, h, w) extent.
            t, h, w, _ = grid.shape
            return [
                grid[t_slice, h_slice, w_slice]
                for t_slice, h_slice, w_slice in self.window_op((t, h, w), self.window)
            ]

        window_partition, window_reverse, window_shape, window_count = win_cache(
            "win_transform",
            lambda: na.window_idx(vid_shape, slice_into_windows),
        )

        def split_heads(qkv: torch.Tensor):
            # (l, 3*h*d) -> three (l, h, d) tensors: q, k, v.
            return rearrange(qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim).unbind(1)

        vid_q, vid_k, vid_v = split_heads(window_partition(vid_qkv))
        txt_q, txt_k, txt_v = split_heads(txt_qkv)

        vid_q, txt_q = self.norm_q(vid_q, txt_q)
        vid_k, txt_k = self.norm_k(vid_k, txt_k)

        # Packed-batch segment lengths. There is one segment per window, sized
        # as that window's video token count plus its sample's text length
        # (each sample's text length is repeated window_count times).
        txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
        vid_len_win = win_cache("vid_len", lambda: window_shape.prod(-1))
        txt_len_win = win_cache("txt_len", lambda: txt_len.repeat_interleave(window_count))
        all_len_win = win_cache("all_len", lambda: vid_len_win + txt_len_win)
        concat_win, unconcat_win = win_cache(
            "mm_pnp",
            lambda: na.repeat_concat_idx(vid_len_win, txt_len, window_count),
        )

        # Rotary embedding on the video q/k only, driven by the window shapes.
        if self.rope:
            vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, win_cache)

        # Packed inputs are cast to bfloat16.
        # cu_seqlens are the int32 prefix sums of the segment lengths, with a leading 0.
        q_packed = concat_win(vid_q, txt_q).bfloat16()
        k_packed = concat_win(vid_k, txt_k).bfloat16()
        v_packed = concat_win(vid_v, txt_v).bfloat16()
        cu_seqlens_q = win_cache(
            "vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
        )
        cu_seqlens_k = win_cache(
            "vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
        )
        max_seqlen_q = win_cache("vid_max_seqlen_q", lambda: all_len_win.max().item())
        max_seqlen_k = win_cache("vid_max_seqlen_k", lambda: all_len_win.max().item())
        out = self.attn(
            q=q_packed,
            k=k_packed,
            v=v_packed,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
        )
        # Release the packed bf16 inputs before allocating the cast-back output.
        del q_packed, k_packed, v_packed
        out = out.type_as(vid_q)

        # NOTE(review): unconcat_win presumably folds the per-window text back
        # to one copy per sample; confirm in na.repeat_concat_idx.
        vid_out, txt_out = unconcat_win(out)
        vid_out = window_reverse(rearrange(vid_out, "l h d -> l (h d)"))
        txt_out = rearrange(txt_out, "l h d -> l (h d)")

        # Inverse redistribution, again vid before txt.
        vid_out = gather_heads_scatter_seq(vid_out, head_dim=1, seq_dim=0)
        txt_out = gather_heads_scatter_seq(txt_out, head_dim=1, seq_dim=0)

        vid_out, txt_out = self.proj_out(vid_out, txt_out)
        return vid_out, txt_out
|
|
|
|
class NaMMSRTransformerBlock(MMWindowTransformerBlock):
    """Pre-norm video/text transformer block whose attention is ``NaSwinAttention``.

    The block has two sub-layers, attention first and then the MLP. Each runs
    norm -> ``ada`` (mode="in") -> op -> ``ada`` (mode="out") and finishes
    with a residual add.
    """

    def __init__(
        self,
        *,
        vid_dim: int,
        txt_dim: int,
        emb_dim: int,
        heads: int,
        head_dim: int,
        expand_ratio: int,
        norm: norm_layer_type,
        norm_eps: float,
        ada: ada_layer_type,
        qk_bias: bool,
        qk_rope: bool,
        qk_norm: norm_layer_type,
        shared_qkv: bool,
        shared_mlp: bool,
        mlp_type: str,
        **kwargs,
    ):
        """Build the parent block, then replace its attention with ``NaSwinAttention``.

        ``**kwargs`` are forwarded to both the parent and ``NaSwinAttention``.
        The latter requires ``window`` and ``window_method``, which this
        constructor does not pass explicitly, so callers must supply them via
        ``**kwargs``. The attention's ``qk_norm_eps`` is taken from ``norm_eps``.
        """
        # Arguments shared by the parent block and the attention layer.
        common = dict(
            vid_dim=vid_dim,
            txt_dim=txt_dim,
            heads=heads,
            head_dim=head_dim,
            qk_bias=qk_bias,
            qk_rope=qk_rope,
            qk_norm=qk_norm,
            shared_qkv=shared_qkv,
        )
        super().__init__(
            **common,
            emb_dim=emb_dim,
            expand_ratio=expand_ratio,
            norm=norm,
            norm_eps=norm_eps,
            ada=ada,
            shared_mlp=shared_mlp,
            mlp_type=mlp_type,
            **kwargs,
        )
        # Overwrites the attention module the parent constructor assigned to self.attn.
        self.attn = NaSwinAttention(**common, qk_norm_eps=norm_eps, **kwargs)

    def forward(
        self,
        vid: torch.FloatTensor,
        txt: torch.FloatTensor,
        vid_shape: torch.LongTensor,
        txt_shape: torch.LongTensor,
        emb: torch.FloatTensor,
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.LongTensor,
        torch.LongTensor,
    ]:
        """Run the attention sub-layer and then the MLP sub-layer.

        Args:
            vid: video tokens.
            txt: text tokens.
            vid_shape: per-sample video shapes; ``prod(-1)`` gives token counts.
            txt_shape: per-sample text shapes; ``prod(-1)`` gives token counts.
            emb: conditioning embedding handed to ``self.ada``.
            cache: keyed store shared with the attention layer.

        Returns:
            ``(vid, txt, vid_shape, txt_shape)``. The two shape tensors are
            passed through unchanged.
        """
        ada_kwargs = dict(
            emb=emb,
            hid_len=MMArg(
                cache("vid_len", lambda: vid_shape.prod(-1)),
                cache("txt_len", lambda: txt_shape.prod(-1)),
            ),
            cache=cache,
            branch_tag=MMArg("vid", "txt"),
        )

        def attend(vid_x, txt_x):
            return self.attn(vid_x, txt_x, vid_shape, txt_shape, cache)

        vid, txt = self._apply_modulated_sublayer(
            "attn", self.attn_norm, attend, vid, txt, ada_kwargs
        )
        vid, txt = self._apply_modulated_sublayer(
            "mlp", self.mlp_norm, self.mlp, vid, txt, ada_kwargs
        )
        return vid, txt, vid_shape, txt_shape

    def _apply_modulated_sublayer(self, layer, norm, op, vid, txt, ada_kwargs):
        """Apply norm -> ada(in) -> op -> ada(out), then add the inputs back."""
        vid_h, txt_h = norm(vid, txt)
        vid_h, txt_h = self.ada(vid_h, txt_h, layer=layer, mode="in", **ada_kwargs)
        vid_h, txt_h = op(vid_h, txt_h)
        vid_h, txt_h = self.ada(vid_h, txt_h, layer=layer, mode="out", **ada_kwargs)
        return vid_h + vid, txt_h + txt
|
|