# Uploaded by PencilHu via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 85752bc (verified), 42.7 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional, List, Dict, Sequence
from einops import rearrange
from .utils import hash_state_dict_keys
from .wan_video_camera_controller import SimpleAdapter
# Optional attention backends. Each probe catches ImportError (the parent of
# ModuleNotFoundError), so a backend that is installed but fails to load (e.g. a
# CUDA/ABI mismatch surfacing as ImportError) is treated as unavailable instead
# of breaking the import of this module.
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ImportError:
    FLASH_ATTN_3_AVAILABLE = False
try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ImportError:
    FLASH_ATTN_2_AVAILABLE = False
try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ImportError:
    SAGE_ATTN_AVAILABLE = False
print("FLASH_ATTN_3_AVAILABLE ",FLASH_ATTN_3_AVAILABLE)
print("FLASH_ATTN_2_AVAILABLE",FLASH_ATTN_2_AVAILABLE)
# Packed variable-length kernel used by attention_per_batch_with_shots:
# prefer FlashAttention-3, fall back to FlashAttention-2, otherwise None.
# `except Exception` (not a bare except) keeps KeyboardInterrupt/SystemExit intact.
try:
    from flash_attn_interface import flash_attn_varlen_func
except Exception:
    try:
        from flash_attn.flash_attn_interface import flash_attn_varlen_func
    except Exception:
        flash_attn_varlen_func = None
# def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
# if compatibility_mode:
# q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
# k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
# v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
# x = F.scaled_dot_product_attention(q, k, v)
# x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
# elif FLASH_ATTN_3_AVAILABLE:
# q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
# k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
# v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
# x = flash_attn_interface.flash_attn_func(q, k, v)
# if isinstance(x,tuple):
# x = x[0]
# x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
# elif FLASH_ATTN_2_AVAILABLE:
# q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
# k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
# v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
# x = flash_attn.flash_attn_func(q, k, v)
# x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
# elif SAGE_ATTN_AVAILABLE:
# q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
# k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
# v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
# x = sageattn(q, k, v)
# x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
# else:
# q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
# k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
# v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
# x = F.scaled_dot_product_attention(q, k, v)
# x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
# return x
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attn_mask=None, shot_latent_indices=None):
    """Multi-head attention on head-packed tensors, using the best available kernel.

    Args:
        q, k, v: tensors laid out as ``[b, s, num_heads * head_dim]``; k/v may have a
            different sequence length than q (cross-attention).
        num_heads: number of heads packed into the last dimension.
        compatibility_mode: if true, force PyTorch ``scaled_dot_product_attention``.
        attn_mask: optional mask forwarded to ``scaled_dot_product_attention``. It forces
            that backend because the flash/sage kernels below are called without masks.
        shot_latent_indices: only acts as a flag here. When given, plain SDPA over the
            full sequence is used. Shot-restricted attention lives in
            ``attention_per_batch_with_shots``.

    Returns:
        Tensor ``[b, s_q, num_heads * head_dim]``.

    When nothing forces SDPA, the backend priority is FlashAttention-3,
    FlashAttention-2, SageAttention, then SDPA.
    """
    force_sdpa = attn_mask is not None or shot_latent_indices is not None or compatibility_mode
    if not force_sdpa and (FLASH_ATTN_3_AVAILABLE or FLASH_ATTN_2_AVAILABLE):
        # Flash kernels take sequence-major [b, s, n, d].
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        if FLASH_ATTN_3_AVAILABLE:
            x = flash_attn_interface.flash_attn_func(q, k, v)
            # Some FlashAttention-3 releases return (out, softmax_lse).
            if isinstance(x, tuple):
                x = x[0]
        else:
            x = flash_attn.flash_attn_func(q, k, v)
        return rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    # SageAttention and SDPA take head-major [b, n, s, d].
    q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
    k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
    v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
    if not force_sdpa and SAGE_ATTN_AVAILABLE:
        x = sageattn(q, k, v)
    else:
        # attn_mask is None unless the caller supplied one; SDPA treats None as "no mask".
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
def build_global_reps_from_shots(
    K_local_shots: List[torch.Tensor],
    V_local_shots: List[torch.Tensor],
    g_per: int,
    mode: str = "firstk"  # "mean" | "firstk" | "linspace"
):
    """Build a shared pool of representative K/V tokens from per-shot tokens.

    Each shot contributes up to ``g_per`` tokens along dim 0, and the picks from
    all shots are concatenated in shot order.

    Args:
        K_local_shots, V_local_shots: per-shot tensors, each ``[Ni, H, D]``.
        g_per: number of representatives per shot (0 disables selection).
        mode: ``"firstk"`` takes the first ``min(g_per, Ni)`` tokens.
            ``"linspace"`` takes ``g_per`` evenly spaced indices in ``[0, Ni - 1]``
            (indices repeat when ``g_per > Ni``). ``"mean"`` currently behaves exactly
            like ``"linspace"``; no averaging is performed.

    Returns:
        ``(K_global, V_global)``, each ``[G_total, H, D]``. When nothing is selected,
        the tensors are empty ``[0, H, D]`` with the first shot's dtype/device, or plain
        ``torch.empty(0)`` when there are no shots.

    Raises:
        ValueError: for an unknown ``mode``. This is only checked for shots that are
            actually sampled (non-empty and ``g_per > 0``).
    """
    if not K_local_shots:
        return (torch.empty(0), torch.empty(0))
    picked_k, picked_v = [], []
    for shot_k, shot_v in zip(K_local_shots, V_local_shots):
        n_tokens = shot_k.size(0)
        if g_per == 0 or n_tokens == 0:
            continue
        if mode == "firstk":
            count = min(g_per, n_tokens)
            picked_k.append(shot_k[:count])
            picked_v.append(shot_v[:count])
        elif mode in ("mean", "linspace"):
            index = torch.linspace(0, n_tokens - 1, steps=g_per, device=shot_k.device).long()
            picked_k.append(shot_k.index_select(0, index))
            picked_v.append(shot_v.index_select(0, index))
        else:
            raise ValueError(f"unknown mode {mode}")
    if picked_k:
        return torch.cat(picked_k, dim=0), torch.cat(picked_v, dim=0)
    ref_k, ref_v = K_local_shots[0], V_local_shots[0]
    empty_k = torch.empty(0, *ref_k.shape[1:], device=ref_k.device, dtype=ref_k.dtype)
    empty_v = torch.empty(0, *ref_v.shape[1:], device=ref_v.device, dtype=ref_v.dtype)
    return (empty_k, empty_v)
def build_ID_reps(
    IDs_2_shots: Dict[int, List[int]],
    K_shots: List[torch.Tensor],  # each: [Ni, H, D]
    V_shots: List[torch.Tensor],  # each: [Ni, H, D]
):
    """Concatenate, for every shot, the K/V tokens of the ID-shots it references.

    Args:
        IDs_2_shots: mapping ``shot_id -> [id_shot_index, ...]``. Each index points into
            ``K_shots`` / ``V_shots``, i.e. it addresses a "special" ID shot.
        K_shots, V_shots: per-shot token tensors, each ``[Ni, H, D]``. Entries may be
            ``None`` or empty; those are skipped.

    Returns:
        ``{shot_id: {"K": K_id, "V": V_id}}``. ``K_id``/``V_id`` are the dim-0
        concatenation of all referenced ID-shots that are in range and non-empty,
        i.e. ``[sum(N_id), H, D]``. A shot with no usable ID-shot gets empty ``[0, H, D]``
        tensors with ``K_shots[0]``'s dtype/device, or ``torch.empty(0)`` when
        ``K_shots`` itself is empty.
    """
    shot_id_kv = {}
    # Previously this loop read an undefined name (`shot_2_IDs`) and always raised NameError.
    for shot_id, id_shot_ids in IDs_2_shots.items():
        reps_k, reps_v = [], []
        for id_sid in id_shot_ids:
            if id_sid < 0 or id_sid >= len(K_shots):
                continue
            Ki = K_shots[id_sid]  # [Ni, H, D] (ID-shot)
            Vi = V_shots[id_sid]
            # An ID-shot may be missing or have no tokens (e.g. padding).
            if Ki is None or Vi is None or Ki.numel() == 0:
                continue
            reps_k.append(Ki)
            reps_v.append(Vi)
        if reps_k:
            shot_id_kv[shot_id] = {
                "K": torch.cat(reps_k, dim=0),
                "V": torch.cat(reps_v, dim=0),
            }
        elif K_shots:
            # No usable ID-shot: return empty tensors that still match the shots' dtype/device.
            device = K_shots[0].device
            dtype = K_shots[0].dtype
            shot_id_kv[shot_id] = {
                "K": torch.empty(0, *K_shots[0].shape[1:], device=device, dtype=dtype),
                "V": torch.empty(0, *V_shots[0].shape[1:], device=device, dtype=dtype),
            }
        else:
            # No shots at all: there is no reference shape to copy.
            shot_id_kv[shot_id] = {"K": torch.empty(0), "V": torch.empty(0)}
    return shot_id_kv
def _cu_seqlens(lengths, device) -> torch.Tensor:
    """Return int32 offsets ``[0, l0, l0 + l1, ...]`` (flash-attn ``cu_seqlens`` format)."""
    offsets = [0]
    for length in lengths:
        offsets.append(offsets[-1] + length)
    return torch.tensor(offsets, device=device, dtype=torch.int32)


def attention_per_batch_with_shots(
    q: torch.Tensor,  # [b, s, n_heads*head_dim]
    k: torch.Tensor,  # [b, s, n_heads*head_dim]
    v: torch.Tensor,  # [b, s, n_heads*head_dim]
    shot_latent_indices: Sequence[Sequence[int]],
    num_heads: int,
    per_g: int = 64,  # latent tokens per frame (see wan_video_new.py, around line 1179)
    ID_2_shot=None,  # per batch, per shot: ID indices, e.g. [[[0, 1], [2]], []]
    dropout_p: float = 0.0,
    causal: bool = False
):
    """Block-diagonal ("per-shot") self-attention using flash-attn's varlen kernel.

    For batch element ``bi``, ``shot_latent_indices[bi]`` is a list of cut points
    ``[0, c1, ..., s]``. The tokens in ``[c_i, c_{i+1})`` form shot ``i``, and its queries
    attend only to that shot's keys/values (plus optional ID tokens, see below).

    Args:
        q, k, v: identically shaped tensors ``[b, s, num_heads * head_dim]``. The dtype
            and device must be supported by the installed flash-attn build.
        shot_latent_indices: per-batch cut points; each list must start at 0 and end at ``s``.
        num_heads: number of attention heads packed into the last dimension.
        per_g: latent tokens per frame. Each ID uses 3 images, i.e. ``3 * per_g`` tokens.
        ID_2_shot: optional. ``ID_2_shot[bi][shot]`` lists ID indices whose tokens,
            located after ``cuts[-1]``, are appended to that shot's keys/values.
        dropout_p: accepted for API compatibility but NOT forwarded, because the
            FlashAttention-3 varlen API has no dropout argument.
        causal: forwarded to ``flash_attn_varlen_func``.

    Returns:
        Tensor ``[b, s, num_heads * head_dim]``.

    Raises:
        RuntimeError: if no flash-attn varlen kernel could be imported.
    """
    assert q.shape == k.shape == v.shape, "shape wrong in attention_per_batch_with_shots"
    b, s_tot, hd = q.shape
    assert hd % num_heads == 0
    # Fail before doing any tensor work.
    if flash_attn_varlen_func is None:
        raise RuntimeError("flash_attn_varlen_func not available. Please install flash-attn v2+.")
    device = q.device
    tokens_per_id = per_g * 3  # each ID is given as 3 images of per_g latent tokens
    outputs = []
    for bi in range(b):
        cuts = list(shot_latent_indices[bi])
        assert cuts[0] == 0 and cuts[-1] == s_tot, "shot_latent_indices must start with 0 and end with s_tot"
        # Token-major [s, n, d] views of this sample, which is the layout the varlen kernel expects.
        q_b = rearrange(q[bi], "s (n d) -> s n d", n=num_heads)
        k_b = rearrange(k[bi], "s (n d) -> s n d", n=num_heads)
        v_b = rearrange(v[bi], "s (n d) -> s n d", n=num_heads)
        ids_for_batch = None
        if ID_2_shot is not None and bi < len(ID_2_shot):
            ids_for_batch = ID_2_shot[bi]
        # Pack the shots back to back: per shot, its own tokens followed by its ID tokens.
        q_pieces, k_pieces, v_pieces = [], [], []
        q_lengths, kv_lengths = [], []
        for shot_id, (start, end) in enumerate(zip(cuts[:-1], cuts[1:])):
            q_pieces.append(q_b[start:end])
            k_pieces.append(k_b[start:end])
            v_pieces.append(v_b[start:end])
            q_lengths.append(end - start)
            kv_len = end - start
            id_list = []
            if ids_for_batch and shot_id < len(ids_for_batch):
                id_list = ids_for_batch[shot_id]
            for id_idx in id_list:
                # NOTE(review): ID tokens are looked up after cuts[-1], but cuts[-1] is
                # asserted to equal s_tot above. For id_idx >= 0 this start is therefore
                # always out of range and the ID is skipped. Confirm the intended layout.
                id_start = cuts[-1] + id_idx * tokens_per_id
                if id_start >= s_tot:
                    continue
                id_end = min(id_start + tokens_per_id, s_tot)
                k_id = k_b[id_start:id_end]
                k_pieces.append(k_id)
                v_pieces.append(v_b[id_start:id_end])
                kv_len += k_id.shape[0]
            kv_lengths.append(kv_len)
        out = flash_attn_varlen_func(
            torch.cat(q_pieces, dim=0),  # [sum_N, n, d]
            torch.cat(k_pieces, dim=0),  # [sum_(N+extra), n, d]
            torch.cat(v_pieces, dim=0),
            _cu_seqlens(q_lengths, device),
            _cu_seqlens(kv_lengths, device),
            max(q_lengths, default=0),
            max(kv_lengths, default=0),
            softmax_scale=None, causal=causal
        )
        # Some FlashAttention-3 releases return (out, softmax_lse).
        if isinstance(out, tuple):
            out = out[0]
        # The shots were packed in sequence order, so `out` is already [s_tot, n, d];
        # no per-shot re-slicing (or .item() GPU syncs) is needed.
        outputs.append(out)
    x = torch.stack(outputs, dim=0)  # [b, s, n, d]
    return rearrange(x, "b s n d -> b s (n d)")
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Adaptive affine modulation: scale ``x`` by ``1 + scale`` and then add ``shift`` (broadcasting)."""
    scaled = x * (1 + scale)
    return scaled + shift
def sinusoidal_embedding_1d(dim, position):
    """Sinusoidal embedding of 1-D positions (e.g. diffusion timesteps).

    Returns ``[len(position), 2 * (dim // 2)]`` laid out as ``[cos(p * w_i) | sin(p * w_i)]``,
    where ``w_i = 10000 ** (-i / (dim // 2))``. The math is done in float64 and the result
    is cast back to ``position.dtype``.
    """
    half = dim // 2
    exponents = torch.arange(half, dtype=torch.float64, device=position.device).div(half)
    inv_freq = torch.pow(10000, -exponents)
    angles = torch.outer(position.type(torch.float64), inv_freq)
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
    return embedding.to(position.dtype)
def precompute_freqs_cis_4d(dim: int, end: int = 1024, theta: float = 10000.0):
    """Precompute RoPE tables for a (shot, frame, height, width) split of ``dim``.

    Frame, height and width each get ``dim // 4`` channels, and the shot axis takes the
    remainder ``dim - 3 * (dim // 4)``. Returns the ``(s, f, h, w)`` tables, each built by
    ``precompute_freqs_cis`` with ``end`` positions.
    """
    # Open question from the original author: should the shot axis use frequencies
    # different from f/h/w?
    quarter = dim // 4
    channels = (dim - 3 * quarter, quarter, quarter, quarter)
    return tuple(precompute_freqs_cis(c, end, theta) for c in channels)
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    """Precompute RoPE tables for a (frame, height, width) split of ``dim``.

    Height and width each get ``dim // 3`` channels, and frame takes the remainder
    ``dim - 2 * (dim // 3)``. Returns the ``(f, h, w)`` tables, each built by
    ``precompute_freqs_cis`` with ``end`` positions.
    """
    third = dim // 3
    channels = (dim - 2 * third, third, third)
    return tuple(precompute_freqs_cis(c, end, theta) for c in channels)
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """Precompute 1-D rotary-embedding phases as unit complex numbers.

    Args:
        dim: channel count of the rotated sub-space. ``dim // 2`` frequencies are produced.
        end: number of positions (0 .. end-1).
        theta: RoPE base.

    Returns:
        complex128 tensor ``[end, dim // 2]`` with entry ``[p, i] = exp(1j * p * theta ** (-2i / dim))``.
    """
    # Inverse frequencies theta^(-2i/dim), computed in float64 for precision.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim))
    angles = torch.outer(torch.arange(end, device=inv_freq.device), inv_freq)
    # float64 angles -> complex128 (the previous "complex64" comment was wrong).
    return torch.polar(torch.ones_like(angles), angles)
def rope_apply(x, freqs, num_heads):
    """Rotate head-packed features ``[b, s, num_heads * head_dim]`` by complex phases ``freqs``.

    Consecutive channel pairs are viewed as complex numbers, multiplied by ``freqs``
    (broadcast against ``[b, s, num_heads, head_dim // 2]``), and flattened back. The math
    runs in float64 and the result keeps ``x``'s dtype.
    """
    batch, seq = x.shape[0], x.shape[1]
    pairs = x.to(torch.float64).reshape(batch, seq, num_heads, -1, 2)
    rotated = torch.view_as_complex(pairs) * freqs
    return torch.view_as_real(rotated).flatten(2).to(x.dtype)
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned per-channel gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        """Divide ``x`` by the RMS of its last axis (``eps`` added under the root)."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalize in float32 for stability, cast back to the input dtype, then apply the gain.
        normalized = self.norm(x.float()).to(x.dtype)
        return normalized * self.weight
class AttentionModule(nn.Module):
    """Parameter-free dispatcher that picks the attention routine for each call."""

    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, q, k, v, attn_mask=None, shot_latent_indices = None, per_g=0, ID_2_shot=None):
        """Attend ``q`` over ``k``/``v``.

        A given ``attn_mask`` takes precedence. Otherwise ``shot_latent_indices`` selects
        the per-shot varlen path, and with neither set the default backend is used.
        """
        if attn_mask is None and shot_latent_indices is not None:
            return attention_per_batch_with_shots(q=q, k=k, v=v, shot_latent_indices=shot_latent_indices, num_heads=self.num_heads, per_g=per_g, ID_2_shot=ID_2_shot)
        # attn_mask is None on the default path, which matches flash_attention's own default.
        return flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attn_mask=attn_mask)
class SelfAttention(nn.Module):
    """Multi-head self-attention with RMS-normalized, rotary-embedded queries and keys."""

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Attribute names (q/k/v/o/norm_q/norm_k) double as checkpoint keys.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.attn = AttentionModule(self.num_heads)

    def forward(self, x, freqs, shot_latent_indices=None, per_g=0, ID_2_shot=None):
        """Self-attend ``x``; ``freqs`` are the RoPE phases applied to queries and keys."""
        query = rope_apply(self.norm_q(self.q(x)), freqs, self.num_heads)
        key = rope_apply(self.norm_k(self.k(x)), freqs, self.num_heads)
        value = self.v(x)
        attended = self.attn(query, key, value, shot_latent_indices=shot_latent_indices, per_g=per_g, ID_2_shot=ID_2_shot)
        return self.o(attended)
class CrossAttention(nn.Module):
    """Cross-attention from video tokens to the conditioning context.

    With ``has_image_input``, the first 257 context tokens are the image-embedding tokens
    (``WanModel.forward`` prepends them to the text context). They are attended through
    separate ``k_img``/``v_img`` projections, and the two attention outputs are summed
    before the output projection.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.has_image_input = has_image_input
        if has_image_input:
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = RMSNorm(dim, eps=eps)
        self.attn = AttentionModule(self.num_heads)

    def forward(self, x: torch.Tensor, y: torch.Tensor, attn_mask=None):
        """Attend ``x`` over context ``y``; ``attn_mask`` applies to the text branch only."""
        text_ctx = y[:, 257:] if self.has_image_input else y
        query = self.norm_q(self.q(x))
        out = self.attn(query, self.norm_k(self.k(text_ctx)), self.v(text_ctx), attn_mask=attn_mask)
        if not self.has_image_input:
            return self.o(out)
        image_ctx = y[:, :257]
        image_out = flash_attention(query, self.norm_k_img(self.k_img(image_ctx)), self.v_img(image_ctx), num_heads=self.num_heads)
        return self.o(out + image_out)
class GateModule(nn.Module):
    """Gated residual connection: returns ``x + gate * residual``."""

    def __init__(self,):
        super().__init__()

    def forward(self, x, gate, residual):
        gated = gate * residual
        return x + gated
class DiTBlock(nn.Module):
    """One transformer block: modulated self-attention, cross-attention, modulated FFN.

    Six modulation vectors (shift/scale/gate for attention and for the FFN) are the sum
    of the learned ``modulation`` table and ``t_mod``. ``t_mod`` is either ``[b, 6, dim]``
    or a 4-D per-token variant that is chunked along axis 2.
    """

    def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(
            dim, num_heads, eps, has_image_input=has_image_input)
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
            approximate='tanh'), nn.Linear(ffn_dim, dim))
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

    def forward(self, x, context, t_mod, freqs, attn_mask=None, shot_latent_indices=None, per_g=0, ID_2_shot=None):
        per_token = t_mod.dim() == 4
        table = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
        mods = (table + t_mod).chunk(6, dim=2 if per_token else 1)
        if per_token:
            mods = tuple(m.squeeze(2) for m in mods)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods
        # Self-attention on the modulated, normalized input (msa = multi-head self-attention).
        attn_out = self.self_attn(
            modulate(self.norm1(x), shift_msa, scale_msa), freqs,
            shot_latent_indices=shot_latent_indices, per_g=per_g, ID_2_shot=ID_2_shot,
        )
        x = self.gate(x, gate_msa, attn_out)
        # Cross-attention to the conditioning context (plain residual, no gate).
        x = x + self.cross_attn(self.norm3(x), context, attn_mask)
        # Modulated FFN (mlp = multi-layer perceptron).
        ffn_out = self.ffn(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return self.gate(x, gate_mlp, ffn_out)
class MLP(torch.nn.Module):
    """LayerNorm -> Linear -> GELU -> Linear -> LayerNorm projection for CLIP features.

    With ``has_pos_emb``, a learned ``[1, 514, 1280]`` positional embedding is added to the
    input first. Those sizes are fixed, so this mode assumes 514 tokens of width 1280.
    """

    def __init__(self, in_dim, out_dim, has_pos_emb=False):
        super().__init__()
        layers = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)
        self.has_pos_emb = has_pos_emb
        if has_pos_emb:
            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))

    def forward(self, x):
        if self.has_pos_emb:
            x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
        return self.proj(x)
class Head(nn.Module):
    """Output projection with timestep-dependent shift/scale modulation.

    ``t_mod`` is either ``[b, dim]``-like, broadcast against the ``[1, 2, dim]`` table, or
    per-token ``[b, s, dim]``. The projection emits ``out_dim * prod(patch_size)`` values
    per token.
    """

    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, t_mod):
        table = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
        if t_mod.dim() == 3:
            # Per-token modulation: [1, 1, 2, dim] + [b, s, 1, dim] -> split along axis 2.
            shift, scale = (table.unsqueeze(0) + t_mod.unsqueeze(2)).chunk(2, dim=2)
            shift, scale = shift.squeeze(2), scale.squeeze(2)
        else:
            shift, scale = (table + t_mod).chunk(2, dim=1)
        return self.head(self.norm(x) * (1 + scale) + shift)
class WanModel(torch.nn.Module):
    """Wan video diffusion transformer (DiT) backbone.

    ``forward`` runs these stages in order:
      1. Conv3d patch embedding of the latent video.
      2. A stack of ``DiTBlock``s: self-attention with 3-D RoPE, cross-attention to the
         text/image context, and a timestep-modulated FFN.
      3. The modulated ``Head`` projection.
      4. ``unpatchify`` back to the latent-video layout.
    """
    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool,
        has_image_pos_emb: bool = False,
        has_ref_conv: bool = False,
        add_control_adapter: bool = False,
        in_dim_control_adapter: int = 24,
        seperated_timestep: bool = False,
        require_vae_embedding: bool = True,
        require_clip_embedding: bool = True,
        fuse_vae_embedding_in_latents: bool = False,
    ):
        """Create all sub-modules.

        Args:
            dim: transformer hidden size.
            in_dim: channels entering ``patch_embedding``. This is the count after
                ``forward`` concatenates ``y`` when ``has_image_input`` is set.
            ffn_dim: hidden size of each block's FFN.
            out_dim: output channels per latent voxel (before unpatchify).
            text_dim: width of the incoming text-encoder features.
            freq_dim: width of the sinusoidal timestep embedding.
            eps: epsilon for the LayerNorm/RMSNorm layers.
            patch_size: (f, h, w) Conv3d kernel and stride used to patchify.
            num_heads: attention heads per block.
            num_layers: number of ``DiTBlock``s.
            has_image_input: enables the CLIP-image branch (``img_emb``, plus the image
                path in each block's cross-attention).
            has_image_pos_emb: gives ``img_emb`` a learned positional embedding.
            has_ref_conv: creates ``ref_conv``. ``forward`` in this file does not use it.
            add_control_adapter: creates a ``SimpleAdapter`` for camera-control latents.
            in_dim_control_adapter: input channels of that adapter.
            seperated_timestep, require_vae_embedding, require_clip_embedding,
            fuse_vae_embedding_in_latents: only stored as attributes. ``forward`` in this
                file does not read them (presumably consumed by the pipeline; verify).
        """
        super().__init__()
        self.dim = dim
        self.in_dim = in_dim
        self.freq_dim = freq_dim
        self.has_image_input = has_image_input
        self.patch_size = patch_size
        self.seperated_timestep = seperated_timestep
        self.require_vae_embedding = require_vae_embedding
        self.require_clip_embedding = require_clip_embedding
        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim)
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        # Produces the 6 per-block modulation vectors (reshaped to [.., 6, dim] in forward).
        self.time_projection = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6))
        self.blocks = nn.ModuleList([
            DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
            for _ in range(num_layers)
        ])
        self.head = Head(dim, out_dim, patch_size, eps)
        head_dim = dim // num_heads
        # (frame, height, width) RoPE tables. These are plain attributes rather than
        # registered buffers, so forward() moves the assembled table to x.device each call.
        self.freqs = precompute_freqs_cis_3d(head_dim)
        # (shot, frame, height, width) RoPE tables.
        # NOTE(review): not read by forward() in this file; presumably used by an external
        # shot-aware forward. Verify.
        self.shot_freqs = precompute_freqs_cis_4d(head_dim)
        if has_image_input:
            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
        if has_ref_conv:
            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
        self.has_image_pos_emb = has_image_pos_emb
        self.has_ref_conv = has_ref_conv
        if add_control_adapter:
            self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
        else:
            self.control_adapter = None

    def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
        """Embed ``x`` with the Conv3d and flatten it to tokens.

        Returns ``(tokens [b, f*h*w, dim], grid_size)`` where ``grid_size`` is the
        ``(f, h, w)`` patch grid. If a control adapter exists and camera latents are
        given, the adapter output is added to the embedding.
        """
        x = self.patch_embedding(x)
        if self.control_adapter is not None and control_camera_latents_input is not None:
            y_camera = self.control_adapter(control_camera_latents_input)
            x = [u + v for u, v in zip(x, y_camera)]
            # NOTE(review): only the first sample is kept here, so batch size > 1 is
            # truncated to 1 on this path. Confirm that this is intended.
            x = x[0].unsqueeze(0)
        grid_size = x.shape[2:]
        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
        return x, grid_size # x, grid_size: (f, h, w)

    def unpatchify(self, x: torch.Tensor, grid_size: Sequence[int]):
        """Inverse of ``patchify``: map tokens ``[b, f*h*w, prod(patch)*c]`` to ``[b, c, F, H, W]``."""
        return rearrange(
            x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
            f=grid_size[0], h=grid_size[1], w=grid_size[2],
            x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
        )

    def forward(self,
            x: torch.Tensor,
            timestep: torch.Tensor,
            context: torch.Tensor,
            clip_feature: Optional[torch.Tensor] = None,
            y: Optional[torch.Tensor] = None,
            use_gradient_checkpointing: bool = False,
            use_gradient_checkpointing_offload: bool = False,
            **kwargs,
            ):
        """Predict the denoising output for latent video ``x`` at ``timestep``.

        Args:
            x: latent video ``[b, c, f, h, w]``.
            timestep: diffusion timesteps (assumed 1-D ``[b]``; TODO confirm).
            context: text-encoder features ``[b, L, text_dim]``.
            clip_feature, y: required when ``has_image_input``. ``y`` is concatenated to
                ``x`` along channels, and the embedded ``clip_feature`` tokens are
                prepended to the text context.
            use_gradient_checkpointing(_offload): checkpoint each block while training,
                optionally saving activations on CPU.
            **kwargs: accepted and ignored. NOTE(review): shot-related options
                (``shot_latent_indices``, ``per_g``, ``ID_2_shot``) supported by
                ``DiTBlock`` are NOT forwarded from here.

        Returns:
            Tensor in the unpatchified latent layout ``[b, out_dim, F, H, W]``.
        """
        t = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, timestep))
        # 6 modulation vectors per sample for every DiTBlock.
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)
        if self.has_image_input:
            x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
            clip_embdding = self.img_emb(clip_feature)
            context = torch.cat([clip_embdding, context], dim=1)
        x, (f, h, w) = self.patchify(x)
        # 3-D RoPE: concatenate frame/height/width phases per token -> [f*h*w, 1, head_dim//2].
        freqs = torch.cat([
            self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward
        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                if use_gradient_checkpointing_offload:
                    with torch.autograd.graph.save_on_cpu():
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x, context, t_mod, freqs,
                            use_reentrant=False,
                        )
                else:
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            else:
                x = block(x, context, t_mod, freqs)
        # The head is modulated by the timestep embedding t (not t_mod).
        x = self.head(x, t)
        x = self.unpatchify(x, (f, h, w))
        return x

    @staticmethod
    def state_dict_converter():
        """Return the converter that maps external checkpoints to this module's keys."""
        return WanModelStateDictConverter()
class WanModelStateDictConverter:
    """Convert external Wan checkpoints to this module's parameter names and infer a config.

    Each ``from_*`` method returns ``(state_dict, config)``. ``config`` is a fresh dict of
    ``WanModel`` constructor kwargs when the checkpoint is recognized by the hash of its
    state-dict keys, and ``{}`` otherwise.
    """

    # Known diffusers-format checkpoints: key-hash -> config.
    _DIFFUSERS_CONFIGS = {
        "cb104773c6c2cb6df4f9529ad5c60d0b": {
            "model_type": "t2v",
            "patch_size": (1, 2, 2),
            "text_len": 512,
            "in_dim": 16,
            "dim": 5120,
            "ffn_dim": 13824,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 40,
            "num_layers": 40,
            "window_size": (-1, -1),
            "qk_norm": True,
            "cross_attn_norm": True,
            "eps": 1e-6,
        },
    }

    # Known native/civitai-format checkpoints: key-hash (after dropping "vace*" keys) -> config.
    # A single lookup replaces a long if/elif chain that recomputed the hash for every
    # branch and contained an unreachable duplicate "6bfcfb3b..." entry.
    _CIVITAI_CONFIGS = {
        "9269f8db9040a9d860eaca435be61814": {
            "has_image_input": False,
            "patch_size": [1, 2, 2],
            "in_dim": 16,
            "dim": 1536,
            "ffn_dim": 8960,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 12,
            "num_layers": 30,
            "eps": 1e-6
        },
        "aafcfd9672c3a2456dc46e1cb6e52c70": {
            "has_image_input": False,
            "patch_size": [1, 2, 2],
            "in_dim": 16,
            "dim": 5120,
            "ffn_dim": 13824,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 40,
            "num_layers": 40,
            "eps": 1e-6
        },
        "6bfcfb3b342cb286ce886889d519a77e": {
            "has_image_input": True,
            "patch_size": [1, 2, 2],
            "in_dim": 36,
            "dim": 5120,
            "ffn_dim": 13824,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 40,
            "num_layers": 40,
            "eps": 1e-6
        },
        "6d6ccde6845b95ad9114ab993d917893": {
            "has_image_input": True,
            "patch_size": [1, 2, 2],
            "in_dim": 36,
            "dim": 1536,
            "ffn_dim": 8960,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 12,
            "num_layers": 30,
            "eps": 1e-6
        },
        # 1.3B PAI control
        "349723183fc063b2bfc10bb2835cf677": {
            "has_image_input": True,
            "patch_size": [1, 2, 2],
            "in_dim": 48,
            "dim": 1536,
            "ffn_dim": 8960,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 12,
            "num_layers": 30,
            "eps": 1e-6
        },
        # 14B PAI control
        "efa44cddf936c70abd0ea28b6cbe946c": {
            "has_image_input": True,
            "patch_size": [1, 2, 2],
            "in_dim": 48,
            "dim": 5120,
            "ffn_dim": 13824,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 40,
            "num_layers": 40,
            "eps": 1e-6
        },
        "3ef3b1f8e1dab83d5b71fd7b617f859f": {
            "has_image_input": True,
            "patch_size": [1, 2, 2],
            "in_dim": 36,
            "dim": 5120,
            "ffn_dim": 13824,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 40,
            "num_layers": 40,
            "eps": 1e-6,
            "has_image_pos_emb": True
        },
        # 1.3B PAI control v1.1
        "70ddad9d3a133785da5ea371aae09504": {
            "has_image_input": True,
            "patch_size": [1, 2, 2],
            "in_dim": 48,
            "dim": 1536,
            "ffn_dim": 8960,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 12,
            "num_layers": 30,
            "eps": 1e-6,
            "has_ref_conv": True
        },
        # 14B PAI control v1.1
        "26bde73488a92e64cc20b0a7485b9e5b": {
            "has_image_input": True,
            "patch_size": [1, 2, 2],
            "in_dim": 48,
            "dim": 5120,
            "ffn_dim": 13824,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 40,
            "num_layers": 40,
            "eps": 1e-6,
            "has_ref_conv": True
        },
        # 1.3B PAI control-camera v1.1
        "ac6a5aa74f4a0aab6f64eb9a72f19901": {
            "has_image_input": True,
            "patch_size": [1, 2, 2],
            "in_dim": 32,
            "dim": 1536,
            "ffn_dim": 8960,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 12,
            "num_layers": 30,
            "eps": 1e-6,
            "has_ref_conv": False,
            "add_control_adapter": True,
            "in_dim_control_adapter": 24,
        },
        # 14B PAI control-camera v1.1
        "b61c605c2adbd23124d152ed28e049ae": {
            "has_image_input": True,
            "patch_size": [1, 2, 2],
            "in_dim": 32,
            "dim": 5120,
            "ffn_dim": 13824,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 40,
            "num_layers": 40,
            "eps": 1e-6,
            "has_ref_conv": False,
            "add_control_adapter": True,
            "in_dim_control_adapter": 24,
        },
        # Wan-AI/Wan2.2-TI2V-5B
        "1f5ab7703c6fc803fdded85ff040c316": {
            "has_image_input": False,
            "patch_size": [1, 2, 2],
            "in_dim": 48,
            "dim": 3072,
            "ffn_dim": 14336,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 48,
            "num_heads": 24,
            "num_layers": 30,
            "eps": 1e-6,
            "seperated_timestep": True,
            "require_clip_embedding": False,
            "require_vae_embedding": False,
            "fuse_vae_embedding_in_latents": True,
        },
        # Wan-AI/Wan2.2-I2V-A14B
        "5b013604280dd715f8457c6ed6d6a626": {
            "has_image_input": False,
            "patch_size": [1, 2, 2],
            "in_dim": 36,
            "dim": 5120,
            "ffn_dim": 13824,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 40,
            "num_layers": 40,
            "eps": 1e-6,
            "require_clip_embedding": False,
        },
        # Wan2.2-Fun-A14B-Control
        "2267d489f0ceb9f21836532952852ee5": {
            "has_image_input": False,
            "patch_size": [1, 2, 2],
            "in_dim": 52,
            "dim": 5120,
            "ffn_dim": 13824,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 40,
            "num_layers": 40,
            "eps": 1e-6,
            "has_ref_conv": True,
            "require_clip_embedding": False,
        },
        # Wan2.2-Fun-A14B-Control-Camera
        "47dbeab5e560db3180adf51dc0232fb1": {
            "has_image_input": False,
            "patch_size": [1, 2, 2],
            "in_dim": 36,
            "dim": 5120,
            "ffn_dim": 13824,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": 40,
            "num_layers": 40,
            "eps": 1e-6,
            "has_ref_conv": False,
            "add_control_adapter": True,
            "in_dim_control_adapter": 24,
            "require_clip_embedding": False,
        },
    }

    def __init__(self):
        pass

    @staticmethod
    def _lookup_config(table, state_dict):
        """Return a deep copy of ``table[hash_state_dict_keys(state_dict)]``, or ``{}``.

        The hash is computed once. The copy ensures callers can modify the returned
        config (including its list values) without corrupting the shared table.
        """
        import copy
        return copy.deepcopy(table.get(hash_state_dict_keys(state_dict), {}))

    def from_diffusers(self, state_dict):
        """Rename diffusers-format keys to this module's keys.

        Per-block keys are matched against block-0 templates and re-targeted to
        their own block index. Keys that match neither a direct nor a template
        entry are dropped.
        """
        rename_dict = {
            "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
            "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
            "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
            "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
            "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
            "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
            "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
            "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
            "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
            "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
            "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
            "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
            "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
            "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
            "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
            "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
            "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
            "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
            "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
            "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
            "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
            "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
            "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
            "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
            "blocks.0.norm2.bias": "blocks.0.norm3.bias",
            "blocks.0.norm2.weight": "blocks.0.norm3.weight",
            "blocks.0.scale_shift_table": "blocks.0.modulation",
            "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
            "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
            "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
            "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
            "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
            "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
            "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
            "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
            "condition_embedder.time_proj.bias": "time_projection.1.bias",
            "condition_embedder.time_proj.weight": "time_projection.1.weight",
            "patch_embedding.bias": "patch_embedding.bias",
            "patch_embedding.weight": "patch_embedding.weight",
            "scale_shift_table": "head.modulation",
            "proj_out.bias": "head.head.bias",
            "proj_out.weight": "head.head.weight",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
                continue
            # "blocks.<i>.rest" -> block-0 template -> rename -> restore index <i>.
            parts = name.split(".")
            template = ".".join(parts[:1] + ["0"] + parts[2:])
            if template in rename_dict:
                renamed = rename_dict[template].split(".")
                state_dict_[".".join(renamed[:1] + [parts[1]] + renamed[2:])] = param
        config = self._lookup_config(self._DIFFUSERS_CONFIGS, state_dict)
        return state_dict_, config

    def from_civitai(self, state_dict):
        """Return the state dict without ``vace*`` keys, plus the matching config (or ``{}``)."""
        # Keys starting with "vace" are excluded before hashing (presumably VACE-adapter
        # weights handled elsewhere; verify).
        state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
        config = self._lookup_config(self._CIVITAI_CONFIGS, state_dict)
        return state_dict, config