# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import types
from copy import deepcopy
from einops import rearrange
from typing import List
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.loaders import PeftAdapterMixin
from ...distributed.sequence_parallel import (
distributed_attention,
gather_forward,
get_rank,
get_world_size,
)
from ..model import (
Head,
WanAttentionBlock,
WanLayerNorm,
WanRMSNorm,
WanModel,
WanSelfAttention,
flash_attention,
rope_params,
sinusoidal_embedding_1d,
rope_apply
)
from .face_blocks import FaceEncoder, FaceAdapter
from .motion_encoder import Generator
class HeadAnimate(Head):
def forward(self, x, e):
"""
Args:
x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
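            # add the time embedding to the learned modulation, then split into shift (e[0]) and scale (e[1])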
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
class WanAnimateSelfAttention(WanSelfAttention):
def forward(self, x, seq_lens, grid_sizes, freqs):
"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
seq_lens(Tensor): Shape [B]
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
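        # apply 3D RoPE to q/k, then run (optionally windowed) flash attention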
x = flash_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
v=v,
k_lens=seq_lens,
window_size=self.window_size)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanAnimateCrossAttention(WanSelfAttention):
def __init__(
self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6,
use_img_emb=True
):
super().__init__(
dim,
num_heads,
window_size,
qk_norm,
eps
)
self.use_img_emb = use_img_emb
if use_img_emb:
self.k_img = nn.Linear(dim, dim)
self.v_img = nn.Linear(dim, dim)
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, context, context_lens):
"""
x: [B, L1, C].
context: [B, L2, C].
context_lens: [B].
"""
        if self.use_img_emb:
            # the first 257 tokens are the CLIP image embeddings; the remainder is text
            context_img = context[:, :257]
            context = context[:, 257:]
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).view(b, -1, n, d)
k = self.norm_k(self.k(context)).view(b, -1, n, d)
v = self.v(context).view(b, -1, n, d)
if self.use_img_emb:
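            # decoupled image branch: dedicated k/v projections over the CLIP image tokens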
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
v_img = self.v_img(context_img).view(b, -1, n, d)
img_x = flash_attention(q, k_img, v_img, k_lens=None)
# compute attention
x = flash_attention(q, k, v, k_lens=context_lens)
# output
x = x.flatten(2)
if self.use_img_emb:
img_x = img_x.flatten(2)
x = x + img_x
x = self.o(x)
return x
class WanAnimateAttentionBlock(nn.Module):
def __init__(self,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
use_img_emb=True):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanAnimateSelfAttention(dim, num_heads, window_size, qk_norm, eps)
self.norm3 = WanLayerNorm(
dim, eps, elementwise_affine=True
) if cross_attn_norm else nn.Identity()
self.cross_attn = WanAnimateCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps, use_img_emb=use_img_emb)
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim),
nn.GELU(approximate='tanh'),
nn.Linear(ffn_dim, dim)
)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)
def forward(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
):
"""
Args:
x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, 6, C]
seq_lens(Tensor): Shape [B], length of each sequence in batch
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
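            # six AdaLN terms: shift/scale/gate for self-attention (e[0:3]) and for the FFN (e[3:6])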
e = (self.modulation + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs
)
with amp.autocast(dtype=torch.float32):
x = x + y * e[2]
# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
with amp.autocast(dtype=torch.float32):
x = x + y * e[5]
return x
x = cross_attn_ffn(x, context, context_lens, e)
return x
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.LayerNorm(in_dim),
torch.nn.Linear(in_dim, in_dim),
torch.nn.GELU(),
torch.nn.Linear(in_dim, out_dim),
torch.nn.LayerNorm(out_dim),
)
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    _no_split_modules = ['WanAnimateAttentionBlock']
@register_to_config
def __init__(self,
patch_size=(1, 2, 2),
text_len=512,
in_dim=36,
dim=5120,
ffn_dim=13824,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=40,
num_layers=40,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
motion_encoder_dim=512,
use_context_parallel=False,
use_img_emb=True):
super().__init__()
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
self.motion_encoder_dim = motion_encoder_dim
self.use_context_parallel = use_context_parallel
self.use_img_emb = use_img_emb
# embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
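        # pose latents are embedded by a parallel Conv3d; 16 input channels, presumably the VAE latent dim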
self.pose_patch_embedding = nn.Conv3d(
16, dim, kernel_size=patch_size, stride=patch_size
)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
self.blocks = nn.ModuleList([
WanAnimateAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
cross_attn_norm, eps, use_img_emb) for _ in range(num_layers)
])
# head
self.head = HeadAnimate(dim, out_dim, patch_size, eps)
        # buffers (registered as non-persistent so they are excluded from the state dict)
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
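        # 3D RoPE: the per-head dim is split across the temporal axis and the two spatial axes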
_freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
], dim=1)
self.register_buffer("freqs", _freqs, persistent=False)
self.img_emb = MLPProj(1280, dim)
# initialize weights
self.init_weights()
self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
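        # the face adapter holds one fuser block per 5 transformer blocks (see after_transformer_block)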
self.face_adapter = FaceAdapter(
heads_num=self.num_heads,
hidden_dim=self.dim,
num_adapter_layers=self.num_layers // 5,
)
self.face_encoder = FaceEncoder(
in_dim=motion_encoder_dim,
hidden_dim=self.dim,
num_heads=4,
)
def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents]
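        # add pose conditioning to every latent frame except the first (presumably the reference frame)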
for x_, pose_latents_ in zip(x, pose_latents):
x_[:, :, 1:] += pose_latents_
        b, c, T, h, w = face_pixel_values.shape
        face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
        # encode the face frames in small batches to bound the motion encoder's peak memory
        encode_bs = 8
        face_pixel_values_tmp = []
        for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)):
            face_pixel_values_tmp.append(
                self.motion_encoder.get_motion(
                    face_pixel_values[i * encode_bs:(i + 1) * encode_bs]))
        motion_vec = torch.cat(face_pixel_values_tmp)
        motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
motion_vec = self.face_encoder(motion_vec)
B, L, H, C = motion_vec.shape
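        # prepend a zero motion token, presumably for the leading reference frame that has no driving face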
pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
return x, motion_vec
def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
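        # fuse the face motion features through the adapter after every 5th block (0, 5, 10, ...)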
if block_idx % 5 == 0:
adapter_args = [x, motion_vec, motion_masks, self.use_context_parallel]
residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
x = residual_out + x
return x
def forward(
self,
x,
t,
clip_fea,
context,
seq_len,
y=None,
pose_latents=None,
face_pixel_values=None
):
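        r"""
        Args (shapes inferred from how the inputs are used below):
            x (List[Tensor]): input video latents, each [C_in, F, H, W]
            t (Tensor): diffusion timesteps, shape [B]
            clip_fea (Tensor): CLIP features of the reference image, [B, 257, 1280]
            context (List[Tensor]): text embeddings, each [L_text, text_dim]
            seq_len (int): maximum sequence length used for padding
            y (List[Tensor], optional): conditioning latents concatenated channel-wise with x
            pose_latents (List[Tensor], optional): driving pose latents (16 channels)
            face_pixel_values (Tensor, optional): driving face frames, [B, C, T, H, W]
        """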
# params
freqs = self.freqs
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
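        # zero-pad every sequence to the shared seq_len and stack into a single batch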
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float()
)
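            # project the time embedding into six per-block modulation vectors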
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
if self.use_img_emb:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=freqs,
context=context,
context_lens=context_lens)
if self.use_context_parallel:
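            # sequence parallelism: each rank keeps only its own chunk of the token sequence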
x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
for idx, block in enumerate(self.blocks):
x = block(x, **kwargs)
x = self.after_transformer_block(idx, x, motion_vec)
# head
x = self.head(x, e)
if self.use_context_parallel:
x = gather_forward(x, dim=1)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
r"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Tensor]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Tensor):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Tensor]:
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
"""
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
def init_weights(self):
r"""
Initialize model parameters using Xavier initialization.
"""
# basic init
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
# init embeddings
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
for m in self.text_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
for m in self.time_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
# init output layer
nn.init.zeros_(self.head.head.weight)
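        # note: pose_patch_embedding (Conv3d) and the motion/face modules created after this call keep their default initializations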