import math

import torch
import torch.nn as nn
from einops import repeat

from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.ldm.common_dit
import comfy.model_management
|
|
|
|
def sinusoidal_embedding_1d(dim, position):
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float32)

    sinusoid = torch.outer(
        position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x
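# Illustrative example (not executed): for dim=256 and a 1-D position tensor of
# shape [N], the result has shape [N, 256]; the first 128 channels are cosine
# terms and the last 128 are sine terms of the outer product above.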
|
|
|
|
class WanSelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6, operation_settings={}):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # layers
        self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
        self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()

    def forward(self, x, freqs):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value projections
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n * d)
            return q, k, v

        q, k, v = qkv_fn(x)
        q, k = apply_rope(q, k, freqs)

        # flatten heads back into the channel dimension for optimized_attention
        x = optimized_attention(
            q.view(b, s, n * d),
            k.view(b, s, n * d),
            v,
            heads=self.num_heads,
        )

        x = self.o(x)
        return x
|
|
|
|
class WanT2VCrossAttention(WanSelfAttention):

    def forward(self, x, context, **kwargs):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
        """
        # query from the video tokens, key/value from the text context
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(context))
        v = self.v(context)

        # compute attention
        x = optimized_attention(q, k, v, heads=self.num_heads)

        x = self.o(x)
        return x
|
|
|
|
class WanI2VCrossAttention(WanSelfAttention):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6, operation_settings={}):
        super().__init__(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)

        self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()

    def forward(self, x, context, context_img_len):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C], image tokens followed by text tokens
        """
        # split the context into image tokens and text tokens
        context_img = context[:, :context_img_len]
        context = context[:, context_img_len:]

        # compute query, key, value
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(context))
        v = self.v(context)
        k_img = self.norm_k_img(self.k_img(context_img))
        v_img = self.v_img(context_img)
        img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
        x = optimized_attention(q, k, v, heads=self.num_heads)

        # sum the image and text attention outputs
        x = x + img_x
        x = self.o(x)
        return x
|
|
|
|
WAN_CROSSATTENTION_CLASSES = {
    't2v_cross_attn': WanT2VCrossAttention,
    'i2v_cross_attn': WanI2VCrossAttention,
}
|
|
|
|
def repeat_e(e, x):
    repeats = 1
    if e.shape[1] > 1:
        repeats = x.shape[1] // e.shape[1]
    if repeats == 1:
        return e
    return torch.repeat_interleave(e, repeats, dim=1)
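# Illustrative example (not executed): if e holds one modulation vector per latent
# frame, e.g. shape [B, T, C], and x holds the flattened patch tokens of shape
# [B, T * H * W, C], each frame's vector is repeated H * W times along dim 1 so
# it lines up with that frame's tokens.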
|
|
|
|
class WanAttentionBlock(nn.Module):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6, operation_settings={}):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers
        self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                          eps, operation_settings=operation_settings)
        self.norm3 = operation_settings.get("operations").LayerNorm(
            dim, eps,
            elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if cross_attn_norm else nn.Identity()
        self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
                                                                      num_heads,
                                                                      (-1, -1),
                                                                      qk_norm,
                                                                      eps, operation_settings=operation_settings)
        self.norm2 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.ffn = nn.Sequential(
            operation_settings.get("operations").Linear(dim, ffn_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
            operation_settings.get("operations").Linear(ffn_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

        # modulation
        self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

    def forward(
        self,
        x,
        e,
        freqs,
        context,
        context_img_len=257,
    ):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, 6, C]
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        # combine the learned modulation table with the timestep embedding into six
        # vectors: shift/scale/gate for self-attention and shift/scale/gate for the FFN
        if e.ndim < 4:
            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
        else:
            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)

        # self-attention
        y = self.self_attn(
            self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
            freqs)

        x = x + y * repeat_e(e[2], x)

        # cross-attention & ffn
        x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
        y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
        x = x + y * repeat_e(e[5], x)
        return x
|
|
|
|
class VaceWanAttentionBlock(WanAttentionBlock):
    def __init__(
        self,
        cross_attn_type,
        dim,
        ffn_dim,
        num_heads,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=False,
        eps=1e-6,
        block_id=0,
        operation_settings={}
    ):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

    def forward(self, c, x, **kwargs):
        # the first block mixes the projected control signal into the main hidden states
        if self.block_id == 0:
            c = self.before_proj(c) + x
        c = super().forward(c, **kwargs)
        # return a skip connection for the main branch along with the updated control states
        c_skip = self.after_proj(c)
        return c_skip, c
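# Note: VaceWanModel.forward_orig (below) scales the returned c_skip by the matching
# vace_strength entry and adds it to the main hidden states after the corresponding
# transformer block, while the updated control states c continue through later vace blocks.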
|
|
|
|
class WanCamAdapter(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
        super(WanCamAdapter, self).__init__()

        # pixel unshuffle: trade spatial resolution for channels (factor 8 per side)
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)

        # convolution: in_dim * 64 channels (after unshuffle) -> out_dim channels
        self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

        # residual refinement blocks
        self.residual_blocks = nn.Sequential(
            *[WanCamResidualBlock(out_dim, operation_settings=operation_settings) for _ in range(num_residual_blocks)]
        )

    def forward(self, x):
        # merge the frame dimension into the batch dimension
        bs, c, f, h, w = x.size()
        x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)

        # pixel unshuffle to [bs * f, c * 64, h / 8, w / 8]
        x_unshuffled = self.pixel_unshuffle(x)

        # project to the transformer hidden size
        x_conv = self.conv(x_unshuffled)

        # residual refinement
        out = self.residual_blocks(x_conv)

        # split the batch dimension back into batch and frames
        out = out.view(bs, f, out.size(1), out.size(2), out.size(3))

        # move channels back in front of the frame dimension
        out = out.permute(0, 2, 1, 3, 4)

        return out
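# Illustrative shape trace (not executed): camera conditions of shape [B, in_dim, F, H, W]
# become [B * F, in_dim * 64, H / 8, W / 8] after the unshuffle, are projected to out_dim
# channels by the strided conv, and come back as [B, out_dim, F, H', W'] so they can be
# added to the patch embedding in CameraWanModel.forward_orig.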
|
|
|
|
class WanCamResidualBlock(nn.Module):
    def __init__(self, dim, operation_settings={}):
        super(WanCamResidualBlock, self).__init__()
        self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        out += residual
        return out
|
|
|
|
class Head(nn.Module):

    def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers
        out_dim = math.prod(patch_size) * out_dim
        self.norm = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.head = operation_settings.get("operations").Linear(dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

        # modulation
        self.modulation = nn.Parameter(torch.empty(1, 2, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C]
        """
        if e.ndim < 3:
            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
        else:
            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)

        x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
        return x
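# The two modulation vectors act as a shift (e[0]) and scale (1 + e[1]) on the normalized
# tokens before the final projection to math.prod(patch_size) * out_dim channels per token.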
|
|
|
|
class MLPProj(torch.nn.Module):

    def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}):
        super().__init__()

        self.proj = torch.nn.Sequential(
            operation_settings.get("operations").LayerNorm(in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear(in_dim, in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
            torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
            operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

        if flf_pos_embed_token_number is not None:
            self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
        else:
            self.emb_pos = None

    def forward(self, image_embeds):
        # when present, add the learned positional embedding and truncate to its token count
        if self.emb_pos is not None:
            image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device)

        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens
|
|
|
|
class WanModel(torch.nn.Module):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 flf_pos_embed_token_number=None,
                 image_model=None,
                 device=None,
                 dtype=None,
                 operations=None,
                 ):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """
|
|
        super().__init__()
        self.dtype = dtype
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}

        assert model_type in ['t2v', 'i2v']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = operations.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
        self.text_embedding = nn.Sequential(
            operations.Linear(text_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
            operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

        self.time_embedding = nn.Sequential(
            operations.Linear(freq_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.SiLU(), operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
        self.time_projection = nn.Sequential(nn.SiLU(), operations.Linear(dim, dim * 6, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

        # blocks
        cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
        self.blocks = nn.ModuleList([
            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                              window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
            for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)

        # rope embedder: the per-head channels are split across the (t, h, w) axes
        d = dim // num_heads
        self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])

        if model_type == 'i2v':
            self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings)
        else:
            self.img_emb = None
|
|
    def forward_orig(
        self,
        x,
        t,
        context,
        clip_fea=None,
        freqs=None,
        transformer_options={},
        **kwargs,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (Tensor):
                Input video tensor with shape [B, C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B]
            context (Tensor):
                Text embeddings with shape [B, L, C]
            clip_fea (Tensor, *optional*):
                CLIP image features for image-to-video mode
            freqs (Tensor, *optional*):
                Rope frequencies for the flattened patch tokens

        Returns:
            Tensor:
                Denoised video tensor with shape [B, C_out, F, H, W]
        """
        # embeddings
        x = self.patch_embedding(x.float()).to(x.dtype)
        grid_sizes = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)

        # time embeddings
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
        e = e.reshape(t.shape[0], -1, e.shape[-1])
        e0 = self.time_projection(e).unflatten(2, (6, self.dim))

        # context
        context = self.text_embedding(context)

        context_img_len = None
        if clip_fea is not None:
            if self.img_emb is not None:
                context_clip = self.img_emb(clip_fea)
                context = torch.concat([context_clip, context], dim=1)
            context_img_len = clip_fea.shape[-2]

        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
                    return out
                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return x
|
|
    def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
        bs, c, t, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)

        # number of patches along each axis after padding
        patch_size = self.patch_size
        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])

        if time_dim_concat is not None:
            time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
            x = torch.cat([x, time_dim_concat], dim=2)
            t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])

        # 3D rope position ids: one (t, h, w) coordinate per patch token
        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

        freqs = self.rope_embedder(img_ids).movedim(1, 2)
        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
|
    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (Tensor):
                Patchified features with shape [B, L, C_out * prod(patch_size)]
            grid_sizes (tuple):
                Patch grid dimensions (F_patches, H_patches, W_patches) before flattening

        Returns:
            Tensor:
                Reconstructed video tensor with shape [B, C_out, F, H, W]
        """
        c = self.out_dim
        u = x
        b = u.shape[0]
        u = u[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
        u = torch.einsum('bfhwpqrc->bcfphqwr', u)
        u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
        return u
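# Illustrative shape trace (not executed): with patch_size=(1, 2, 2), out_dim=16 and
# grid_sizes=(F, H/2, W/2), unpatchify turns [B, F * H/2 * W/2, 64] token features back
# into a latent of shape [B, 16, F, H, W] by folding each token into its 1x2x2 patch.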
|
|
|
|
class VaceWanModel(WanModel):
    r"""
    Wan diffusion backbone with additional VACE control blocks that inject a
    control video signal into the main transformer blocks.
    """

    def __init__(self,
                 model_type='vace',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 flf_pos_embed_token_number=None,
                 image_model=None,
                 vace_layers=None,
                 vace_in_dim=None,
                 device=None,
                 dtype=None,
                 operations=None,
                 ):
        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}

        # vace blocks
        if vace_layers is not None:
            self.vace_layers = vace_layers
            self.vace_in_dim = vace_in_dim

            self.vace_blocks = nn.ModuleList([
                VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_id=i, operation_settings=operation_settings)
                for i in range(self.vace_layers)
            ])

            # map main block index -> vace block index for the layers that receive control
            self.vace_layers_mapping = {i: n for n, i in enumerate(range(0, self.num_layers, self.num_layers // self.vace_layers))}

            # vace patch embeddings
            self.vace_patch_embedding = operations.Conv3d(
                self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=torch.float32
            )
|
|
    def forward_orig(
        self,
        x,
        t,
        context,
        vace_context,
        vace_strength,
        clip_fea=None,
        freqs=None,
        transformer_options={},
        **kwargs,
    ):
        # embeddings
        x = self.patch_embedding(x.float()).to(x.dtype)
        grid_sizes = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)

        # time embeddings
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))

        # context
        context = self.text_embedding(context)

        context_img_len = None
        if clip_fea is not None:
            if self.img_emb is not None:
                context_clip = self.img_emb(clip_fea)
                context = torch.concat([context_clip, context], dim=1)
            context_img_len = clip_fea.shape[-2]

        # patchify the vace control signal; one entry per control video
        orig_shape = list(vace_context.shape)
        vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:])
        c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
        c = c.flatten(2).transpose(1, 2)
        c = list(c.split(orig_shape[0], dim=0))

        # keep the pre-block token embeddings; the first vace block adds them to the control stream
        x_orig = x

        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
                    return out
                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)

            # inject the scaled vace skip connection at the mapped layers
            ii = self.vace_layers_mapping.get(i, None)
            if ii is not None:
                for iii in range(len(c)):
                    c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
                    x += c_skip * vace_strength[iii]
                del c_skip

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return x
|
|
class CameraWanModel(WanModel):
    r"""
    Wan diffusion backbone with an extra camera-control adapter that adds
    encoded camera conditions to the patch embeddings.
    """

    def __init__(self,
                 model_type='camera',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 flf_pos_embed_token_number=None,
                 image_model=None,
                 in_dim_control_adapter=24,
                 device=None,
                 dtype=None,
                 operations=None,
                 ):
        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}

        self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
|
|
|
|
    def forward_orig(
        self,
        x,
        t,
        context,
        clip_fea=None,
        freqs=None,
        camera_conditions=None,
        transformer_options={},
        **kwargs,
    ):
        # embeddings, with the camera adapter output added to the patch embeddings
        x = self.patch_embedding(x.float()).to(x.dtype)
        if self.control_adapter is not None and camera_conditions is not None:
            x = x + self.control_adapter(camera_conditions).to(x.dtype)
        grid_sizes = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)

        # time embeddings
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))

        # context
        context = self.text_embedding(context)

        context_img_len = None
        if clip_fea is not None:
            if self.img_emb is not None:
                context_clip = self.img_emb(clip_fea)
                context = torch.concat([context_clip, context], dim=1)
            context_img_len = clip_fea.shape[-2]

        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
                    return out
                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return x
|
|