from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.transformers.transformer_wan import (
    WanTimeTextImageEmbedding,
    WanRotaryPosEmbed,
    WanTransformerBlock,
)


logger = logging.get_logger(__name__)


def zero_module(module):
    """Zero-initialize all parameters of a module so that, following the usual
    ControlNet convention, the control branch initially contributes nothing."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


class WanControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    A ControlNet transformer model for video-like data, used with the Wan model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of attention heads.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `3`):
            The number of channels in the controlnet input video.
        vae_channels (`int`, defaults to `16`):
            The number of channels in the VAE latent input.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in the feed-forward network.
        num_layers (`int`, defaults to `20`):
            The number of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            The type of query/key normalization to use.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        image_dim (`int`, *optional*, defaults to `None`):
            Input dimension for image embeddings. If `None`, no image conditioning is used.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        rope_max_seq_len (`int`, defaults to `1024`):
            Maximum sequence length for the rotary positional embeddings.
        downscale_coef (`int`, defaults to `8`):
            Coefficient for spatially downscaling the controlnet input video.
        out_proj_dim (`int`, defaults to `128 * 12`):
            Output projection dimension for the final linear layers.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 3,
        vae_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 20,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        downscale_coef: int = 8,
        out_proj_dim: int = 128 * 12,
    ) -> None:
        super().__init__()

        start_channels = in_channels * (downscale_coef ** 2)
        input_channels = [start_channels, start_channels // 2, start_channels // 4]
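
        # The control encoder reduces the raw control video to the latent grid:
        # the first conv downsamples H and W by `downscale_coef`, and the two
        # stride-(2, 1, 1) convs downsample time by roughly 4x, mirroring the
        # Wan VAE's spatial/temporal compression.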
        self.control_encoder = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv3d(
                        in_channels,
                        input_channels[0],
                        kernel_size=(3, downscale_coef + 1, downscale_coef + 1),
                        stride=(1, downscale_coef, downscale_coef),
                        padding=(1, downscale_coef // 2, downscale_coef // 2),
                    ),
                    nn.GELU(approximate="tanh"),
                    nn.GroupNorm(2, input_channels[0]),
                ),
                nn.Sequential(
                    nn.Conv3d(input_channels[0], input_channels[1], kernel_size=3, stride=(2, 1, 1), padding=1),
                    nn.GELU(approximate="tanh"),
                    nn.GroupNorm(2, input_channels[1]),
                ),
                nn.Sequential(
                    nn.Conv3d(input_channels[1], input_channels[2], kernel_size=3, stride=(2, 1, 1), padding=1),
                    nn.GELU(approximate="tanh"),
                    nn.GroupNorm(2, input_channels[2]),
                ),
            ]
        )

        inner_dim = num_attention_heads * attention_head_dim
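
        # 3D rotary embeddings over the patchified sequence; the patch embedding
        # consumes the VAE latents concatenated with the encoded control features.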
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(
            vae_channels + input_channels[2], inner_dim, kernel_size=patch_size, stride=patch_size
        )
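
        # Shared Wan conditioning embedder: produces the time embedding, the
        # per-block modulation projection (hence time_proj_dim = inner_dim * 6,
        # six shift/scale/gate vectors), and projected text/image embeddings.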
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
        )

        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )
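
        # One zero-initialized linear head per transformer block; each projects the
        # block's hidden states to the channel width expected by the base model.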
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.blocks)):
            controlnet_block = nn.Linear(inner_dim, out_proj_dim)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # Weight the LoRA layers by setting `lora_scale` for each PEFT layer.
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        rotary_emb = self.rope(hidden_states)
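
        # Encode the raw control video down to the latent grid so it can be
        # concatenated with the VAE latents along the channel dimension.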
        for control_encoder_block in self.control_encoder:
            controlnet_states = control_encoder_block(controlnet_states)

        hidden_states = torch.cat([hidden_states, controlnet_states], dim=1)

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
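
        # Timesteps may be provided per token (2D); broadcast them to the full
        # sequence length before embedding.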
        if timestep.ndim == 2:
            if hidden_states.shape[1] != timestep.shape[1]:
                timestep = timestep.repeat_interleave(hidden_states.shape[1] // timestep.shape[1], dim=1)
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        if ts_seq_len is not None:
            # (batch_size, seq_len, 6, inner_dim)
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            # (batch_size, 6, inner_dim)
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
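
        # Run the transformer blocks, collecting each block's output through its
        # zero-initialized projection head; the tuple of per-block features is
        # intended to be added as residuals inside the base Wan transformer.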
        controlnet_hidden_states = ()
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
                controlnet_hidden_states += (controlnet_block(hidden_states),)
        else:
            for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
                controlnet_hidden_states += (controlnet_block(hidden_states),)

        if USE_PEFT_BACKEND:
            # Remove `lora_scale` from each PEFT layer.
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_hidden_states,)

        return Transformer2DModelOutput(sample=controlnet_hidden_states)


if __name__ == "__main__":
    parameters = {
        "added_kv_proj_dim": None,
        "attention_head_dim": 128,
        "cross_attn_norm": True,
        "eps": 1e-06,
        "ffn_dim": 8960,
        "freq_dim": 256,
        "image_dim": None,
        "in_channels": 3,
        "num_attention_heads": 12,
        "num_layers": 2,
        "patch_size": [1, 2, 2],
        "qk_norm": "rms_norm_across_heads",
        "rope_max_seq_len": 1024,
        "text_dim": 4096,
        "downscale_coef": 8,
        "out_proj_dim": 12 * 128,
        "vae_channels": 16,
    }
    controlnet = WanControlnet(**parameters)
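
    # Shapes follow Wan's 8x spatial / 4x temporal VAE compression: a 49-frame
    # 480x720 control video aligns with 13x60x90 latents, and the (1, 2, 2) patch
    # size yields a token sequence of 13 * 30 * 45 = 17550, matching `timestep`.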
    hidden_states = torch.rand(1, 16, 13, 60, 90)
    timestep = torch.tensor([1000]).repeat(17550).unsqueeze(0)
    encoder_hidden_states = torch.rand(1, 512, 4096)
    controlnet_states = torch.rand(1, 3, 49, 480, 720)

    controlnet_hidden_states = controlnet(
        hidden_states=hidden_states,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
        controlnet_states=controlnet_states,
        return_dict=False,
    )
|
|
print("Output states count", len(controlnet_hidden_states[0])) |
|
|
for out_hidden_states in controlnet_hidden_states[0]: |
|
|
print(out_hidden_states.shape) |
|
|
|
|
|
|