| | from typing import Any, List, Tuple, Optional, Union, Dict |
| | from einops import rearrange |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from diffusers.models import ModelMixin |
| | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| |
|
| | from .activation_layers import get_activation_layer |
| | from .norm_layers import get_norm_layer |
| | from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection |
| | from .attenion import attention, parallel_attention, get_cu_seqlens |
| | from .posemb_layers import apply_rotary_emb |
| | from .mlp_layers import MLP, MLPEmbedder, FinalLayer |
| | from .modulate_layers import ModulateDiT, modulate, apply_gate |
| | from .token_refiner import SingleTokenRefiner |
| |
|
| |
|
| | class MMDoubleStreamBlock(nn.Module): |
| | """ |
| | A multimodal dit block with seperate modulation for |
| | text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206 |
| | (Flux.1): https://github.com/black-forest-labs/flux |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | hidden_size: int, |
| | heads_num: int, |
| | mlp_width_ratio: float, |
| | mlp_act_type: str = "gelu_tanh", |
| | qk_norm: bool = True, |
| | qk_norm_type: str = "rms", |
| | qkv_bias: bool = False, |
| | dtype: Optional[torch.dtype] = None, |
| | device: Optional[torch.device] = None, |
| | ): |
| | factory_kwargs = {"device": device, "dtype": dtype} |
| | super().__init__() |
| |
|
| | self.deterministic = False |
| | self.heads_num = heads_num |
| | head_dim = hidden_size // heads_num |
| | mlp_hidden_dim = int(hidden_size * mlp_width_ratio) |
| |
|
| | self.img_mod = ModulateDiT( |
| | hidden_size, |
| | factor=6, |
| | act_layer=get_activation_layer("silu"), |
| | **factory_kwargs, |
| | ) |
| | self.img_norm1 = nn.LayerNorm( |
| | hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs |
| | ) |
| |
|
| | self.img_attn_qkv = nn.Linear( |
| | hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs |
| | ) |
| | qk_norm_layer = get_norm_layer(qk_norm_type) |
| | self.img_attn_q_norm = ( |
| | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) |
| | if qk_norm |
| | else nn.Identity() |
| | ) |
| | self.img_attn_k_norm = ( |
| | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) |
| | if qk_norm |
| | else nn.Identity() |
| | ) |
| | self.img_attn_proj = nn.Linear( |
| | hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs |
| | ) |
| |
|
| | self.img_norm2 = nn.LayerNorm( |
| | hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs |
| | ) |
| | self.img_mlp = MLP( |
| | hidden_size, |
| | mlp_hidden_dim, |
| | act_layer=get_activation_layer(mlp_act_type), |
| | bias=True, |
| | **factory_kwargs, |
| | ) |
| |
|
| | self.txt_mod = ModulateDiT( |
| | hidden_size, |
| | factor=6, |
| | act_layer=get_activation_layer("silu"), |
| | **factory_kwargs, |
| | ) |
| | self.txt_norm1 = nn.LayerNorm( |
| | hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs |
| | ) |
| |
|
| | self.txt_attn_qkv = nn.Linear( |
| | hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs |
| | ) |
| | self.txt_attn_q_norm = ( |
| | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) |
| | if qk_norm |
| | else nn.Identity() |
| | ) |
| | self.txt_attn_k_norm = ( |
| | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) |
| | if qk_norm |
| | else nn.Identity() |
| | ) |
| | self.txt_attn_proj = nn.Linear( |
| | hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs |
| | ) |
| |
|
| | self.txt_norm2 = nn.LayerNorm( |
| | hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs |
| | ) |
| | self.txt_mlp = MLP( |
| | hidden_size, |
| | mlp_hidden_dim, |
| | act_layer=get_activation_layer(mlp_act_type), |
| | bias=True, |
| | **factory_kwargs, |
| | ) |
| | self.hybrid_seq_parallel_attn = None |
| |
|
| | def enable_deterministic(self): |
| | self.deterministic = True |
| |
|
| | def disable_deterministic(self): |
| | self.deterministic = False |
| |
|
| | def forward( |
| | self, |
| | img: torch.Tensor, |
| | txt: torch.Tensor, |
| | vec: torch.Tensor, |
| | cu_seqlens_q: Optional[torch.Tensor] = None, |
| | cu_seqlens_kv: Optional[torch.Tensor] = None, |
| | max_seqlen_q: Optional[int] = None, |
| | max_seqlen_kv: Optional[int] = None, |
| | freqs_cis: tuple = None, |
| | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | ( |
| | img_mod1_shift, |
| | img_mod1_scale, |
| | img_mod1_gate, |
| | img_mod2_shift, |
| | img_mod2_scale, |
| | img_mod2_gate, |
| | ) = self.img_mod(vec).chunk(6, dim=-1) |
| | ( |
| | txt_mod1_shift, |
| | txt_mod1_scale, |
| | txt_mod1_gate, |
| | txt_mod2_shift, |
| | txt_mod2_scale, |
| | txt_mod2_gate, |
| | ) = self.txt_mod(vec).chunk(6, dim=-1) |
| |
|
| | |
| | img_modulated = self.img_norm1(img) |
| | img_modulated = modulate( |
| | img_modulated, shift=img_mod1_shift, scale=img_mod1_scale |
| | ) |
| | img_qkv = self.img_attn_qkv(img_modulated) |
| | img_q, img_k, img_v = rearrange( |
| | img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num |
| | ) |
| | |
| | img_q = self.img_attn_q_norm(img_q).to(img_v) |
| | img_k = self.img_attn_k_norm(img_k).to(img_v) |
| |
|
| | |
| | if freqs_cis is not None: |
| | img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) |
| | assert ( |
| | img_qq.shape == img_q.shape and img_kk.shape == img_k.shape |
| | ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" |
| | img_q, img_k = img_qq, img_kk |
| |
|
| | |
| | txt_modulated = self.txt_norm1(txt) |
| | txt_modulated = modulate( |
| | txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale |
| | ) |
| | txt_qkv = self.txt_attn_qkv(txt_modulated) |
| | txt_q, txt_k, txt_v = rearrange( |
| | txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num |
| | ) |
| | |
| | txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) |
| | txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) |
| |
|
| | |
| | q = torch.cat((img_q, txt_q), dim=1) |
| | k = torch.cat((img_k, txt_k), dim=1) |
| | v = torch.cat((img_v, txt_v), dim=1) |
| | assert ( |
| | cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1 |
| | ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}" |
| | |
| | |
| | if not self.hybrid_seq_parallel_attn: |
| | attn = attention( |
| | q, |
| | k, |
| | v, |
| | cu_seqlens_q=cu_seqlens_q, |
| | cu_seqlens_kv=cu_seqlens_kv, |
| | max_seqlen_q=max_seqlen_q, |
| | max_seqlen_kv=max_seqlen_kv, |
| | batch_size=img_k.shape[0], |
| | ) |
| | else: |
| | attn = parallel_attention( |
| | self.hybrid_seq_parallel_attn, |
| | q, |
| | k, |
| | v, |
| | img_q_len=img_q.shape[1], |
| | img_kv_len=img_k.shape[1], |
| | cu_seqlens_q=cu_seqlens_q, |
| | cu_seqlens_kv=cu_seqlens_kv |
| | ) |
| | |
| | |
| |
|
| | img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :] |
| |
|
| | |
| | img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) |
| | img = img + apply_gate( |
| | self.img_mlp( |
| | modulate( |
| | self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale |
| | ) |
| | ), |
| | gate=img_mod2_gate, |
| | ) |
| |
|
| | |
| | txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate) |
| | txt = txt + apply_gate( |
| | self.txt_mlp( |
| | modulate( |
| | self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale |
| | ) |
| | ), |
| | gate=txt_mod2_gate, |
| | ) |
| |
|
| | return img, txt |
| |
|
| |
|
| | class MMSingleStreamBlock(nn.Module): |
| | """ |
| | A DiT block with parallel linear layers as described in |
| | https://arxiv.org/abs/2302.05442 and adapted modulation interface. |
| | Also refer to (SD3): https://arxiv.org/abs/2403.03206 |
| | (Flux.1): https://github.com/black-forest-labs/flux |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | hidden_size: int, |
| | heads_num: int, |
| | mlp_width_ratio: float = 4.0, |
| | mlp_act_type: str = "gelu_tanh", |
| | qk_norm: bool = True, |
| | qk_norm_type: str = "rms", |
| | qk_scale: float = None, |
| | dtype: Optional[torch.dtype] = None, |
| | device: Optional[torch.device] = None, |
| | ): |
| | factory_kwargs = {"device": device, "dtype": dtype} |
| | super().__init__() |
| |
|
| | self.deterministic = False |
| | self.hidden_size = hidden_size |
| | self.heads_num = heads_num |
| | head_dim = hidden_size // heads_num |
| | mlp_hidden_dim = int(hidden_size * mlp_width_ratio) |
| | self.mlp_hidden_dim = mlp_hidden_dim |
| | self.scale = qk_scale or head_dim ** -0.5 |
| |
|
| | |
| | self.linear1 = nn.Linear( |
| | hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs |
| | ) |
| | |
| | self.linear2 = nn.Linear( |
| | hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs |
| | ) |
| |
|
| | qk_norm_layer = get_norm_layer(qk_norm_type) |
| | self.q_norm = ( |
| | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) |
| | if qk_norm |
| | else nn.Identity() |
| | ) |
| | self.k_norm = ( |
| | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) |
| | if qk_norm |
| | else nn.Identity() |
| | ) |
| |
|
| | self.pre_norm = nn.LayerNorm( |
| | hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs |
| | ) |
| |
|
| | self.mlp_act = get_activation_layer(mlp_act_type)() |
| | self.modulation = ModulateDiT( |
| | hidden_size, |
| | factor=3, |
| | act_layer=get_activation_layer("silu"), |
| | **factory_kwargs, |
| | ) |
| | self.hybrid_seq_parallel_attn = None |
| |
|
| | def enable_deterministic(self): |
| | self.deterministic = True |
| |
|
| | def disable_deterministic(self): |
| | self.deterministic = False |
| |
|
| | def forward( |
| | self, |
| | x: torch.Tensor, |
| | vec: torch.Tensor, |
| | txt_len: int, |
| | cu_seqlens_q: Optional[torch.Tensor] = None, |
| | cu_seqlens_kv: Optional[torch.Tensor] = None, |
| | max_seqlen_q: Optional[int] = None, |
| | max_seqlen_kv: Optional[int] = None, |
| | freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, |
| | ) -> torch.Tensor: |
| | mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) |
| | x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale) |
| | qkv, mlp = torch.split( |
| | self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1 |
| | ) |
| |
|
| | q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) |
| |
|
| | |
| | q = self.q_norm(q).to(v) |
| | k = self.k_norm(k).to(v) |
| |
|
| | |
| | if freqs_cis is not None: |
| | img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] |
| | img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] |
| | img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) |
| | assert ( |
| | img_qq.shape == img_q.shape and img_kk.shape == img_k.shape |
| | ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" |
| | img_q, img_k = img_qq, img_kk |
| | q = torch.cat((img_q, txt_q), dim=1) |
| | k = torch.cat((img_k, txt_k), dim=1) |
| |
|
| | |
| | assert ( |
| | cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1 |
| | ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}" |
| | |
| | |
| | if not self.hybrid_seq_parallel_attn: |
| | attn = attention( |
| | q, |
| | k, |
| | v, |
| | cu_seqlens_q=cu_seqlens_q, |
| | cu_seqlens_kv=cu_seqlens_kv, |
| | max_seqlen_q=max_seqlen_q, |
| | max_seqlen_kv=max_seqlen_kv, |
| | batch_size=x.shape[0], |
| | ) |
| | else: |
| | attn = parallel_attention( |
| | self.hybrid_seq_parallel_attn, |
| | q, |
| | k, |
| | v, |
| | img_q_len=img_q.shape[1], |
| | img_kv_len=img_k.shape[1], |
| | cu_seqlens_q=cu_seqlens_q, |
| | cu_seqlens_kv=cu_seqlens_kv |
| | ) |
| | |
| |
|
| | |
| | output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) |
| | return x + apply_gate(output, gate=mod_gate) |
| |
|
| |
|
| | class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): |
| | """ |
| | HunyuanVideo Transformer backbone |
| | |
| | Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline. |
| | |
| | Reference: |
| | [1] Flux.1: https://github.com/black-forest-labs/flux |
| | [2] MMDiT: http://arxiv.org/abs/2403.03206 |
| | |
| | Parameters |
| | ---------- |
| | args: argparse.Namespace |
| | The arguments parsed by argparse. |
| | patch_size: list |
| | The size of the patch. |
| | in_channels: int |
| | The number of input channels. |
| | out_channels: int |
| | The number of output channels. |
| | hidden_size: int |
| | The hidden size of the transformer backbone. |
| | heads_num: int |
| | The number of attention heads. |
| | mlp_width_ratio: float |
| | The ratio of the hidden size of the MLP in the transformer block. |
| | mlp_act_type: str |
| | The activation function of the MLP in the transformer block. |
| | depth_double_blocks: int |
| | The number of transformer blocks in the double blocks. |
| | depth_single_blocks: int |
| | The number of transformer blocks in the single blocks. |
| | rope_dim_list: list |
| | The dimension of the rotary embedding for t, h, w. |
| | qkv_bias: bool |
| | Whether to use bias in the qkv linear layer. |
| | qk_norm: bool |
| | Whether to use qk norm. |
| | qk_norm_type: str |
| | The type of qk norm. |
| | guidance_embed: bool |
| | Whether to use guidance embedding for distillation. |
| | text_projection: str |
| | The type of the text projection, default is single_refiner. |
| | use_attention_mask: bool |
| | Whether to use attention mask for text encoder. |
| | dtype: torch.dtype |
| | The dtype of the model. |
| | device: torch.device |
| | The device of the model. |
| | """ |
| |
|
| | @register_to_config |
| | def __init__( |
| | self, |
| | args: Any, |
| | patch_size: list = [1, 2, 2], |
| | in_channels: int = 4, |
| | out_channels: int = None, |
| | hidden_size: int = 3072, |
| | heads_num: int = 24, |
| | mlp_width_ratio: float = 4.0, |
| | mlp_act_type: str = "gelu_tanh", |
| | mm_double_blocks_depth: int = 20, |
| | mm_single_blocks_depth: int = 40, |
| | rope_dim_list: List[int] = [16, 56, 56], |
| | qkv_bias: bool = True, |
| | qk_norm: bool = True, |
| | qk_norm_type: str = "rms", |
| | guidance_embed: bool = False, |
| | text_projection: str = "single_refiner", |
| | use_attention_mask: bool = True, |
| | dtype: Optional[torch.dtype] = None, |
| | device: Optional[torch.device] = None, |
| | ): |
| | factory_kwargs = {"device": device, "dtype": dtype} |
| | super().__init__() |
| |
|
| | self.patch_size = patch_size |
| | self.in_channels = in_channels |
| | self.out_channels = in_channels if out_channels is None else out_channels |
| | self.unpatchify_channels = self.out_channels |
| | self.guidance_embed = guidance_embed |
| | self.rope_dim_list = rope_dim_list |
| |
|
| | |
| | |
| | self.use_attention_mask = use_attention_mask |
| | self.text_projection = text_projection |
| |
|
| | self.text_states_dim = args.text_states_dim |
| | self.text_states_dim_2 = args.text_states_dim_2 |
| |
|
| | if hidden_size % heads_num != 0: |
| | raise ValueError( |
| | f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}" |
| | ) |
| | pe_dim = hidden_size // heads_num |
| | if sum(rope_dim_list) != pe_dim: |
| | raise ValueError( |
| | f"Got {rope_dim_list} but expected positional dim {pe_dim}" |
| | ) |
| | self.hidden_size = hidden_size |
| | self.heads_num = heads_num |
| |
|
| | |
| | self.img_in = PatchEmbed( |
| | self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs |
| | ) |
| |
|
| | |
| | if self.text_projection == "linear": |
| | self.txt_in = TextProjection( |
| | self.text_states_dim, |
| | self.hidden_size, |
| | get_activation_layer("silu"), |
| | **factory_kwargs, |
| | ) |
| | elif self.text_projection == "single_refiner": |
| | self.txt_in = SingleTokenRefiner( |
| | self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs |
| | ) |
| | else: |
| | raise NotImplementedError( |
| | f"Unsupported text_projection: {self.text_projection}" |
| | ) |
| |
|
| | |
| | self.time_in = TimestepEmbedder( |
| | self.hidden_size, get_activation_layer("silu"), **factory_kwargs |
| | ) |
| |
|
| | |
| | self.vector_in = MLPEmbedder( |
| | self.text_states_dim_2, self.hidden_size, **factory_kwargs |
| | ) |
| |
|
| | |
| | self.guidance_in = ( |
| | TimestepEmbedder( |
| | self.hidden_size, get_activation_layer("silu"), **factory_kwargs |
| | ) |
| | if guidance_embed |
| | else None |
| | ) |
| |
|
| | |
| | self.double_blocks = nn.ModuleList( |
| | [ |
| | MMDoubleStreamBlock( |
| | self.hidden_size, |
| | self.heads_num, |
| | mlp_width_ratio=mlp_width_ratio, |
| | mlp_act_type=mlp_act_type, |
| | qk_norm=qk_norm, |
| | qk_norm_type=qk_norm_type, |
| | qkv_bias=qkv_bias, |
| | **factory_kwargs, |
| | ) |
| | for _ in range(mm_double_blocks_depth) |
| | ] |
| | ) |
| |
|
| | |
| | self.single_blocks = nn.ModuleList( |
| | [ |
| | MMSingleStreamBlock( |
| | self.hidden_size, |
| | self.heads_num, |
| | mlp_width_ratio=mlp_width_ratio, |
| | mlp_act_type=mlp_act_type, |
| | qk_norm=qk_norm, |
| | qk_norm_type=qk_norm_type, |
| | **factory_kwargs, |
| | ) |
| | for _ in range(mm_single_blocks_depth) |
| | ] |
| | ) |
| |
|
| | self.final_layer = FinalLayer( |
| | self.hidden_size, |
| | self.patch_size, |
| | self.out_channels, |
| | get_activation_layer("silu"), |
| | **factory_kwargs, |
| | ) |
| |
|
| | def enable_deterministic(self): |
| | for block in self.double_blocks: |
| | block.enable_deterministic() |
| | for block in self.single_blocks: |
| | block.enable_deterministic() |
| |
|
| | def disable_deterministic(self): |
| | for block in self.double_blocks: |
| | block.disable_deterministic() |
| | for block in self.single_blocks: |
| | block.disable_deterministic() |
| |
|
| | def forward( |
| | self, |
| | x: torch.Tensor, |
| | t: torch.Tensor, |
| | text_states: torch.Tensor = None, |
| | text_mask: torch.Tensor = None, |
| | text_states_2: Optional[torch.Tensor] = None, |
| | freqs_cos: Optional[torch.Tensor] = None, |
| | freqs_sin: Optional[torch.Tensor] = None, |
| | guidance: torch.Tensor = None, |
| | return_dict: bool = True, |
| | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: |
| | out = {} |
| | img = x |
| | txt = text_states |
| | _, _, ot, oh, ow = x.shape |
| | tt, th, tw = ( |
| | ot // self.patch_size[0], |
| | oh // self.patch_size[1], |
| | ow // self.patch_size[2], |
| | ) |
| |
|
| | |
| | vec = self.time_in(t) |
| |
|
| | |
| | vec = vec + self.vector_in(text_states_2) |
| |
|
| | |
| | if self.guidance_embed: |
| | if guidance is None: |
| | raise ValueError( |
| | "Didn't get guidance strength for guidance distilled model." |
| | ) |
| |
|
| | |
| | vec = vec + self.guidance_in(guidance) |
| |
|
| | |
| | img = self.img_in(img) |
| | if self.text_projection == "linear": |
| | txt = self.txt_in(txt) |
| | elif self.text_projection == "single_refiner": |
| | txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None) |
| | else: |
| | raise NotImplementedError( |
| | f"Unsupported text_projection: {self.text_projection}" |
| | ) |
| |
|
| | txt_seq_len = txt.shape[1] |
| | img_seq_len = img.shape[1] |
| |
|
| | |
| | cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len) |
| | cu_seqlens_kv = cu_seqlens_q |
| | max_seqlen_q = img_seq_len + txt_seq_len |
| | max_seqlen_kv = max_seqlen_q |
| |
|
| | freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None |
| | |
| | for _, block in enumerate(self.double_blocks): |
| | double_block_args = [ |
| | img, |
| | txt, |
| | vec, |
| | cu_seqlens_q, |
| | cu_seqlens_kv, |
| | max_seqlen_q, |
| | max_seqlen_kv, |
| | freqs_cis, |
| | ] |
| |
|
| | img, txt = block(*double_block_args) |
| |
|
| | |
| | x = torch.cat((img, txt), 1) |
| | if len(self.single_blocks) > 0: |
| | for _, block in enumerate(self.single_blocks): |
| | single_block_args = [ |
| | x, |
| | vec, |
| | txt_seq_len, |
| | cu_seqlens_q, |
| | cu_seqlens_kv, |
| | max_seqlen_q, |
| | max_seqlen_kv, |
| | (freqs_cos, freqs_sin), |
| | ] |
| |
|
| | x = block(*single_block_args) |
| |
|
| | img = x[:, :img_seq_len, ...] |
| |
|
| | |
| | img = self.final_layer(img, vec) |
| |
|
| | img = self.unpatchify(img, tt, th, tw) |
| | if return_dict: |
| | out["x"] = img |
| | return out |
| | return img |
| |
|
| | def unpatchify(self, x, t, h, w): |
| | """ |
| | x: (N, T, patch_size**2 * C) |
| | imgs: (N, H, W, C) |
| | """ |
| | c = self.unpatchify_channels |
| | pt, ph, pw = self.patch_size |
| | assert t * h * w == x.shape[1] |
| |
|
| | x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw)) |
| | x = torch.einsum("nthwcopq->nctohpwq", x) |
| | imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) |
| |
|
| | return imgs |
| |
|
| | def params_count(self): |
| | counts = { |
| | "double": sum( |
| | [ |
| | sum(p.numel() for p in block.img_attn_qkv.parameters()) |
| | + sum(p.numel() for p in block.img_attn_proj.parameters()) |
| | + sum(p.numel() for p in block.img_mlp.parameters()) |
| | + sum(p.numel() for p in block.txt_attn_qkv.parameters()) |
| | + sum(p.numel() for p in block.txt_attn_proj.parameters()) |
| | + sum(p.numel() for p in block.txt_mlp.parameters()) |
| | for block in self.double_blocks |
| | ] |
| | ), |
| | "single": sum( |
| | [ |
| | sum(p.numel() for p in block.linear1.parameters()) |
| | + sum(p.numel() for p in block.linear2.parameters()) |
| | for block in self.single_blocks |
| | ] |
| | ), |
| | "total": sum(p.numel() for p in self.parameters()), |
| | } |
| | counts["attn+mlp"] = counts["double"] + counts["single"] |
| | return counts |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | HUNYUAN_VIDEO_CONFIG = { |
| | "HYVideo-T/2": { |
| | "mm_double_blocks_depth": 20, |
| | "mm_single_blocks_depth": 40, |
| | "rope_dim_list": [16, 56, 56], |
| | "hidden_size": 3072, |
| | "heads_num": 24, |
| | "mlp_width_ratio": 4, |
| | }, |
| | "HYVideo-T/2-cfgdistill": { |
| | "mm_double_blocks_depth": 20, |
| | "mm_single_blocks_depth": 40, |
| | "rope_dim_list": [16, 56, 56], |
| | "hidden_size": 3072, |
| | "heads_num": 24, |
| | "mlp_width_ratio": 4, |
| | "guidance_embed": True, |
| | }, |
| | } |
| |
|