import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple, Union

from diffusers.models import ModelMixin
from diffusers.models.attention_processor import Attention
from einops import rearrange

from ..wanvideo.modules.attention import sageattn_func


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension `dim` and the
    position indices `pos`. The `theta` parameter scales the frequencies. When `use_real=False`, the returned
    tensor contains complex values.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar.
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation.
        use_real (`bool`, *optional*):
            If `True`, return the real and imaginary parts separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for context extrapolation.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for NTK-Aware RoPE.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, the real and imaginary parts are each interleaved with themselves to reach
            `dim`. Otherwise, they are concatenated with themselves.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            The dtype of the frequency tensor.

    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    theta = theta * ntk_factor
    freqs = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim))
        / linear_factor
    )  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]

    is_npu = freqs.device.type == "npu"
    if is_npu:
        freqs = freqs.float()

    if use_real and repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # stable audio, allegro
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis
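

# Illustrative usage (not part of the original module): two of the output layouts of
# get_1d_rotary_pos_embed, here for an assumed head dim of 8 and 4 positions.
def _demo_1d_rotary_layouts():
    # Complex form (what WanRotaryPosEmbed below consumes): [S, D/2] complex values.
    freqs_cis = get_1d_rotary_pos_embed(8, 4, use_real=False)
    assert freqs_cis.shape == (4, 4) and freqs_cis.is_complex()

    # Real form (flux / hunyuan-dit / cogvideox): cos and sin interleaved up to [S, D].
    cos, sin = get_1d_rotary_pos_embed(8, 4, use_real=True, repeat_interleave_real=True)
    assert cos.shape == sin.shape == (4, 8)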
class WanRotaryPosEmbed(nn.Module):
    def __init__(
        self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
    ):
        super().__init__()
        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim

        freqs = []
        for dim in [t_dim, h_dim, w_dim]:
            freq = get_1d_rotary_pos_embed(
                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
            )
            freqs.append(freq)
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        self.freqs = self.freqs.to(hidden_states.device)
        freqs = self.freqs.split_with_sizes(
            [
                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
                self.attention_head_dim // 6,
                self.attention_head_dim // 6,
            ],
            dim=1,
        )

        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
        return freqs
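

# Illustrative shape check (not part of the original module): the dummy latent shape
# and head dim below are assumptions chosen to match the (1, 2, 2) patch size.
def _demo_rope_shapes():
    rope = WanRotaryPosEmbed(attention_head_dim=128, patch_size=(1, 2, 2), max_seq_len=1024)
    latent = torch.randn(1, 16, 9, 60, 104)  # dummy [B, C, F, H, W] latent
    freqs = rope(latent)
    # One complex frequency vector of size head_dim // 2 per post-patchify token:
    # 9 frames x (60 // 2) x (104 // 2) spatial positions.
    assert freqs.shape == (1, 1, 9 * 30 * 52, 64)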
def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module


class SimpleAttnProcessor2_0:
    def __init__(self, attention_mode):
        self.attention_mode = attention_mode

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)  # [b, head, l, c]

        if rotary_emb is not None:

            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                return x_out.type_as(hidden_states)

            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)

        if self.attention_mode == "sdpa":
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        elif self.attention_mode == "sageattn":
            hidden_states = sageattn_func(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        else:
            raise NotImplementedError(f"Unknown attention_mode: {self.attention_mode}")

        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class SimpleCogVideoXLayerNormZero(nn.Module):
    def __init__(
        self,
        conditioning_dim: int,
        embedding_dim: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_dim, 3 * embedding_dim, bias=bias)
        self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
        shift, scale, gate = self.linear(self.silu(temb)).chunk(3, dim=1)
        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
        return hidden_states, gate[:, None, :]


class SingleAttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        ffn_dim,
        num_heads,
        time_embed_dim=512,
        qk_norm="rms_norm_across_heads",
        eps=1e-6,
        attention_mode="sdpa",
    ):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.qk_norm = qk_norm
        self.eps = eps

        # layers
        self.norm1 = SimpleCogVideoXLayerNormZero(
            time_embed_dim, dim, elementwise_affine=True, eps=1e-5, bias=True
        )
        self.self_attn = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            processor=SimpleAttnProcessor2_0(attention_mode),
        )
        self.norm2 = SimpleCogVideoXLayerNormZero(
            time_embed_dim, dim, elementwise_affine=True, eps=1e-5, bias=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ffn_dim, dim),
        )

    def forward(
        self,
        hidden_states,
        temb,
        rotary_emb,
    ):
        # norm & modulate
        norm_hidden_states, gate_msa = self.norm1(hidden_states, temb)

        # attention
        attn_hidden_states = self.self_attn(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
        hidden_states = hidden_states + gate_msa * attn_hidden_states

        # norm & modulate
        norm_hidden_states, gate_ff = self.norm2(hidden_states, temb)

        # feed-forward
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = hidden_states + gate_ff * ff_output

        return hidden_states
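

# Illustrative sketch (all sizes are assumptions): one pass through a
# SingleAttentionBlock, pairing it with a WanRotaryPosEmbed whose head dim matches
# dim // num_heads and whose token grid matches the sequence length.
def _demo_single_block():
    block = SingleAttentionBlock(dim=128, ffn_dim=512, num_heads=4, time_embed_dim=512)
    rope = WanRotaryPosEmbed(attention_head_dim=32, patch_size=(1, 2, 2), max_seq_len=64)
    latent = torch.randn(2, 16, 2, 8, 8)     # dummy [B, C, F, H, W] latent
    tokens = torch.randn(2, 2 * 4 * 4, 128)  # [B, F * (H//2) * (W//2), dim] after patchify
    temb = torch.randn(2, 512)               # per-sample conditioning embedding
    out = block(tokens, temb, rope(latent))  # adaLN-zero: temb gates both sub-layers
    assert out.shape == tokens.shape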
class MaskCamEmbed(nn.Module):
    def __init__(self, controlnet_cfg) -> None:
        super().__init__()
        # padding bug fixed
        if controlnet_cfg.get("interp", False):
            # (W-left, W-right, H-top, H-bottom, F-front, F-back); I2V interpolation, first and last frames
            self.mask_padding = [0, 0, 0, 0, 3, 3]
        else:
            # (W-left, W-right, H-top, H-bottom, F-front, F-back); I2V
            self.mask_padding = [0, 0, 0, 0, 3, 0]

        add_channels = controlnet_cfg.get("add_channels", 1)
        mid_channels = controlnet_cfg.get("mid_channels", 64)
        self.mask_proj = nn.Sequential(
            nn.Conv3d(add_channels, mid_channels, kernel_size=(4, 8, 8), stride=(4, 8, 8)),
            nn.GroupNorm(mid_channels // 8, mid_channels),
            nn.SiLU(),
        )
        self.mask_zero_proj = zero_module(
            nn.Conv3d(mid_channels, controlnet_cfg["conv_out_dim"], kernel_size=(1, 2, 2), stride=(1, 2, 2))
        )

    def forward(self, add_inputs: torch.Tensor):
        # render_mask.shape [b, c, f, h, w]
        warp_add_pad = F.pad(add_inputs, self.mask_padding, mode="constant", value=0)
        add_embeds = self.mask_proj(warp_add_pad)  # [B, C, F, H, W]
        add_embeds = self.mask_zero_proj(add_embeds)
        add_embeds = rearrange(add_embeds, "b c f h w -> b (f h w) c")
        return add_embeds


class WanControlNet(ModelMixin):
    def __init__(self, controlnet_cfg):
        super().__init__()
        self.rope_max_seq_len = 1024
        self.patch_size = (1, 2, 2)
        self.in_channels = controlnet_cfg["in_channels"]
        self.dim = controlnet_cfg["dim"]
        self.num_heads = controlnet_cfg["num_heads"]
        self.quantized = controlnet_cfg["quantized"]
        self.base_dtype = controlnet_cfg["base_dtype"]

        if controlnet_cfg["conv_out_dim"] != controlnet_cfg["dim"]:
            self.proj_in = nn.Linear(controlnet_cfg["conv_out_dim"], controlnet_cfg["dim"])
        else:
            self.proj_in = nn.Identity()

        self.controlnet_blocks = nn.ModuleList(
            [
                SingleAttentionBlock(
                    dim=self.dim,
                    ffn_dim=controlnet_cfg["ffn_dim"],
                    num_heads=self.num_heads,
                    time_embed_dim=controlnet_cfg["time_embed_dim"],
                    qk_norm="rms_norm_across_heads",
                    attention_mode=controlnet_cfg["attention_mode"],
                )
                for _ in range(controlnet_cfg["num_layers"])
            ]
        )
        # 5120: hidden size of the base transformer the residuals are injected into.
        self.proj_out = nn.ModuleList(
            [zero_module(nn.Linear(self.dim, 5120)) for _ in range(controlnet_cfg["num_layers"])]
        )
        self.gradient_checkpointing = False

        self.controlnet_rope = WanRotaryPosEmbed(self.dim // self.num_heads, self.patch_size, self.rope_max_seq_len)
        self.controlnet_patch_embedding = nn.Conv3d(
            self.in_channels, controlnet_cfg["conv_out_dim"], kernel_size=self.patch_size, stride=self.patch_size,
            dtype=torch.float32,
        )
        self.controlnet_mask_embedding = MaskCamEmbed(controlnet_cfg)

    def forward(self, render_latent, render_mask, camera_embedding, temb, device):
        controlnet_rotary_emb = self.controlnet_rope(render_latent)

        controlnet_inputs = self.controlnet_patch_embedding(render_latent.to(torch.float32))
        if not self.quantized:
            controlnet_inputs = controlnet_inputs.to(render_latent.dtype)
        else:
            controlnet_inputs = controlnet_inputs.to(self.base_dtype)
        controlnet_inputs = controlnet_inputs.flatten(2).transpose(1, 2)

        # additional inputs (mask, camera embedding)
        add_inputs = None
        if camera_embedding is not None and render_mask is not None:
            add_inputs = torch.cat([render_mask, camera_embedding], dim=1)
        elif render_mask is not None:
            add_inputs = render_mask
        if add_inputs is not None:
            add_inputs = self.controlnet_mask_embedding(add_inputs)
            controlnet_inputs = controlnet_inputs + add_inputs

        hidden_states = self.proj_in(controlnet_inputs)

        controlnet_states = []
        for i, block in enumerate(self.controlnet_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                temb=temb,
                rotary_emb=controlnet_rotary_emb,
            )
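            # Each layer's output is projected through its own zero-initialized linear
            # head, so the injected residuals start at exactly zero and the base model's
            # behavior is untouched until the ControlNet has been trained.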
            controlnet_states.append(self.proj_out[i](hidden_states).to(device))

        return controlnet_states
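

# Illustrative end-to-end sketch (the config values below are assumptions, not the
# shipped configuration): build a tiny WanControlNet and run one forward pass. The
# render mask lives in pixel space (4x temporal, 8x spatial resolution of the latent),
# which MaskCamEmbed's (4, 8, 8) conv plus its (1, 2, 2) projection reduce to one
# embedding per latent patch.
def _demo_controlnet():
    cfg = {
        "in_channels": 16, "conv_out_dim": 64, "dim": 128, "ffn_dim": 512,
        "num_heads": 4, "num_layers": 2, "time_embed_dim": 512,
        "attention_mode": "sdpa", "quantized": False, "base_dtype": torch.float32,
    }
    model = WanControlNet(cfg)
    render_latent = torch.randn(1, 16, 3, 8, 8)  # [B, C, F, H, W] VAE latent
    render_mask = torch.randn(1, 1, 9, 64, 64)   # pixel-space mask: (9 + 3) // 4 = 3 latent frames
    temb = torch.randn(1, 512)                   # timestep/conditioning embedding
    states = model(render_latent, render_mask, None, temb, device="cpu")
    # One residual per layer, one token per latent patch, projected to the base
    # model's hidden size; all zeros at initialization because proj_out is zeroed.
    assert len(states) == 2 and states[0].shape == (1, 3 * 4 * 4, 5120)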