| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from typing import Tuple, Optional, List, Dict, Sequence |
| from einops import rearrange |
| from .utils import hash_state_dict_keys |
| from .wan_video_camera_controller import SimpleAdapter |
# Optional attention backends. Each flag records whether the package could be imported.
# ImportError (the parent of ModuleNotFoundError) is caught so that a package which is
# installed but fails to load (e.g. a compiled extension with missing symbols) falls back
# to another backend instead of breaking the import of this whole module.
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ImportError:
    FLASH_ATTN_3_AVAILABLE = False


try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ImportError:
    FLASH_ATTN_2_AVAILABLE = False


try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ImportError:
    SAGE_ATTN_AVAILABLE = False


print("FLASH_ATTN_3_AVAILABLE ",FLASH_ATTN_3_AVAILABLE)
print("FLASH_ATTN_2_AVAILABLE",FLASH_ATTN_2_AVAILABLE)

# Variable-length kernel used by attention_per_batch_with_shots: prefer the FlashAttention-3
# interface, then FlashAttention-2; None when neither can be loaded. `except Exception`
# (rather than a bare except) keeps KeyboardInterrupt/SystemExit propagating.
try:
    from flash_attn_interface import flash_attn_varlen_func
except Exception:
    try:
        from flash_attn.flash_attn_interface import flash_attn_varlen_func
    except Exception:
        flash_attn_varlen_func = None
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def _shot_local_attention(q, k, v, num_heads, shot_latent_indices):
    """Attention restricted to each shot: one SDPA call per ``[start, end)`` token segment.

    ``q``, ``k``, ``v`` are ``[b, s, num_heads * d]`` and are assumed to share the same
    sequence length (self-attention). ``shot_latent_indices[bi]`` lists the cut points of
    sample ``bi``; they must start at 0, end at ``s`` and be non-decreasing.
    Returns ``[b, s, num_heads * d]``.
    """
    q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
    k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
    v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
    seq_len = q.shape[2]
    per_sample = []
    for bi in range(q.shape[0]):
        cuts = list(shot_latent_indices[bi])
        if (
            not cuts
            or cuts[0] != 0
            or cuts[-1] != seq_len
            or any(end < start for start, end in zip(cuts[:-1], cuts[1:]))
        ):
            raise ValueError(
                "shot_latent_indices must start with 0, end with the sequence length "
                "and be non-decreasing"
            )
        segments = [
            F.scaled_dot_product_attention(
                q[bi:bi + 1, :, start:end],
                k[bi:bi + 1, :, start:end],
                v[bi:bi + 1, :, start:end],
            )
            for start, end in zip(cuts[:-1], cuts[1:])
            if end > start  # zero-length shots contribute no tokens
        ]
        # With validated cuts, no segments means the sequence itself is empty.
        per_sample.append(torch.cat(segments, dim=2) if segments else q[bi:bi + 1])
    x = torch.cat(per_sample, dim=0)
    return rearrange(x, "b n s d -> b s (n d)", n=num_heads)


def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attn_mask=None, shot_latent_indices=None):
    """Multi-head attention on head-packed ``[b, s, num_heads * d]`` tensors.

    Backend selection (first match wins):
      1. ``attn_mask`` given -> ``F.scaled_dot_product_attention`` with that mask.
      2. ``shot_latent_indices`` given -> attention restricted to each shot, computed
         per segment with SDPA (queries only see keys/values of their own shot).
      3. ``compatibility_mode`` -> plain SDPA.
      4. FlashAttention-3, then FlashAttention-2, then SageAttention, if importable.
      5. Plain SDPA.

    Args:
        q: queries ``[b, s_q, num_heads * d]``.
        k, v: keys/values ``[b, s_kv, num_heads * d]``.
        num_heads: number of heads to split the last dimension into.
        compatibility_mode: force the SDPA path.
        attn_mask: optional mask forwarded to SDPA.
        shot_latent_indices: optional per-sample shot cut points (see
            ``_shot_local_attention``).

    Returns:
        Tensor ``[b, s_q, num_heads * d]``.
    """
    if attn_mask is not None:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    elif shot_latent_indices is not None:
        # Previously this branch never computed attention and then read an undefined `x`.
        x = _shot_local_attention(q, k, v, num_heads, shot_latent_indices)
    elif compatibility_mode:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_3_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn_interface.flash_attn_func(q, k, v)
        # Some FA3 builds return (out, softmax_lse).
        if isinstance(x, tuple):
            x = x[0]
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_2_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn.flash_attn_func(q, k, v)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif SAGE_ATTN_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = sageattn(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    else:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    return x
|
|
|
|
def build_global_reps_from_shots(
    K_local_shots: List[torch.Tensor],
    V_local_shots: List[torch.Tensor],
    g_per: int,
    mode: str = "firstk"
):
    """Build a shared pool of representative K/V tokens from per-shot local K/V.

    Each shot contributes up to ``g_per`` tokens selected along dim 0; the picks from all
    shots are concatenated into one pool.

    Args:
        K_local_shots: per-shot keys, each ``[N_i, H, D]`` (only dim 0 is indexed here).
        V_local_shots: per-shot values, aligned one-to-one with ``K_local_shots``.
        g_per: number of representatives requested per shot; 0 selects nothing.
        mode: ``"firstk"`` takes the first ``min(g_per, N_i)`` tokens. ``"linspace"``
            takes ``g_per`` evenly spaced (truncated) indices, so duplicates appear when
            ``g_per > N_i``. ``"mean"`` currently behaves exactly like ``"linspace"``;
            no averaging is performed.

    Returns:
        ``(K_global, V_global)``, each ``[G_total, H, D]``. When no shot is given at all,
        two 1-D empty tensors are returned. When every shot is empty (or ``g_per == 0``),
        the empty tensors keep the trailing shape, device and dtype of the first shot.

    Raises:
        ValueError: if ``mode`` is not one of ``"mean"``, ``"firstk"``, ``"linspace"``.
    """
    # Validate up front so a bad mode is reported even when every shot is empty.
    if mode not in ("mean", "firstk", "linspace"):
        raise ValueError(f"unknown mode {mode}")
    if len(K_local_shots) == 0:
        return (torch.empty(0), torch.empty(0))

    reps_k, reps_v = [], []
    for Ki, Vi in zip(K_local_shots, V_local_shots):
        Ni = Ki.size(0)
        if Ni == 0 or g_per == 0:
            continue
        if mode == "firstk":
            take = min(g_per, Ni)
            reps_k.append(Ki[:take])
            reps_v.append(Vi[:take])
        else:  # "linspace" and its current alias "mean"
            idx = torch.linspace(0, Ni - 1, steps=g_per, device=Ki.device).long()
            reps_k.append(Ki.index_select(0, idx))
            reps_v.append(Vi.index_select(0, idx))

    if not reps_k:
        k0, v0 = K_local_shots[0], V_local_shots[0]
        return (torch.empty(0, *k0.shape[1:], device=k0.device, dtype=k0.dtype),
                torch.empty(0, *v0.shape[1:], device=v0.device, dtype=v0.dtype))
    return torch.cat(reps_k, dim=0), torch.cat(reps_v, dim=0)
|
|
def build_ID_reps(
    IDs_2_shots: Dict[int, List[int]],
    K_shots: List[torch.Tensor],
    V_shots: List[torch.Tensor],
):
    """Gather, for every shot, the K/V tokens of the "ID" shots it references.

    Args:
        IDs_2_shots: mapping ``shot_id -> [id_shot_index, ...]``, where each
            ``id_shot_index`` indexes into ``K_shots`` / ``V_shots``. Out-of-range
            (including negative) indices are skipped.
        K_shots: per-shot keys, each ``[N_i, H, D]``; ``None`` entries are allowed and skipped.
        V_shots: per-shot values aligned with ``K_shots``.

    Returns:
        ``{shot_id: {"K": K_id, "V": V_id}}`` where ``K_id`` / ``V_id`` are the referenced
        shots' tokens concatenated along dim 0 (``[sum(N_id), H, D]``). Shots with no usable
        references get zero-length tensors shaped like the first non-``None`` K/V pair
        (device and dtype taken from its K), or 1-D empty tensors if there is none.
    """
    # Template for empty results. This was previously always K_shots[0], which crashed on
    # an empty list or a None first entry.
    ref_idx = next(
        (i for i, (Ki, Vi) in enumerate(zip(K_shots, V_shots)) if Ki is not None and Vi is not None),
        None,
    )

    def _empty_pair():
        if ref_idx is None:
            return {"K": torch.empty(0), "V": torch.empty(0)}
        k_ref, v_ref = K_shots[ref_idx], V_shots[ref_idx]
        return {
            "K": torch.empty(0, *k_ref.shape[1:], device=k_ref.device, dtype=k_ref.dtype),
            "V": torch.empty(0, *v_ref.shape[1:], device=k_ref.device, dtype=k_ref.dtype),
        }

    shot_id_kv = {}
    # Iterate the actual parameter (the old body referenced an undefined `shot_2_IDs`).
    for shot_id, id_shot_ids in IDs_2_shots.items():
        reps_k, reps_v = [], []
        for id_sid in id_shot_ids:
            if id_sid < 0 or id_sid >= len(K_shots):
                continue
            Ki = K_shots[id_sid]
            Vi = V_shots[id_sid]
            if Ki is None or Vi is None or Ki.numel() == 0:
                continue
            reps_k.append(Ki)
            reps_v.append(Vi)

        if reps_k:
            shot_id_kv[shot_id] = {
                "K": torch.cat(reps_k, dim=0),
                "V": torch.cat(reps_v, dim=0),
            }
        else:
            shot_id_kv[shot_id] = _empty_pair()

    return shot_id_kv
|
|
def _cu_seqlens(lengths, device):
    """Cumulative offsets ``[0, l0, l0 + l1, ...]`` as int32, the format flash-attn varlen expects."""
    offsets = torch.tensor([0] + list(lengths), device=device, dtype=torch.int32)
    return torch.cumsum(offsets, dim=0, dtype=torch.int32)


def _pack_shots_with_id_tokens(k_b, v_b, bounds, ids_for_batch, tokens_per_id, id_base):
    """Per-shot K/V for one sample: each shot's own tokens followed by its ID tokens.

    ``k_b`` / ``v_b`` are ``[s, n, d]``. ID ``i`` of a shot occupies tokens
    ``[id_base + i * tokens_per_id, +tokens_per_id)``, clipped to ``s``. IDs starting at or
    past ``s`` are skipped. Returns ``(k_packed, v_packed, kv_lengths)``.
    """
    seq_len = k_b.shape[0]
    k_list, v_list, kv_lengths = [], [], []
    for shot_id, (start, end) in enumerate(bounds):
        id_list = ids_for_batch[shot_id] if shot_id < len(ids_for_batch) else []
        k_parts = [k_b[start:end]]
        v_parts = [v_b[start:end]]
        for id_idx in id_list:
            id_start = id_base + id_idx * tokens_per_id
            if id_start >= seq_len:
                continue
            id_end = min(id_start + tokens_per_id, seq_len)
            k_parts.append(k_b[id_start:id_end])
            v_parts.append(v_b[id_start:id_end])
        k_list.append(torch.cat(k_parts, dim=0))
        v_list.append(torch.cat(v_parts, dim=0))
        kv_lengths.append(sum(part.shape[0] for part in k_parts))
    return torch.cat(k_list, dim=0), torch.cat(v_list, dim=0), kv_lengths


def attention_per_batch_with_shots(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    shot_latent_indices: Sequence[Sequence[int]],
    num_heads: int,
    per_g: int = 64,
    ID_2_shot=None,
    dropout_p: float = 0.0,
    causal: bool = False
):
    """Shot-local attention, packed into one ``flash_attn_varlen_func`` call per sample.

    For sample ``bi`` the sequence is split at ``shot_latent_indices[bi]``. Each shot's
    queries attend only to that shot's keys/values, optionally extended with ID tokens
    selected through ``ID_2_shot``.

    Args:
        q, k, v: ``[b, s, num_heads * d]``; all three must have the same shape.
        shot_latent_indices: per-sample cut points ``[0, c1, ..., s]`` (non-decreasing).
        num_heads: number of attention heads.
        per_g: ID tokens are read in blocks of ``3 * per_g`` tokens.
        ID_2_shot: optional, indexed ``ID_2_shot[bi][shot_id] -> [id_idx, ...]``. ID
            ``id_idx`` starts at token ``s + id_idx * 3 * per_g``.
        dropout_p: accepted for API compatibility; not used.
        causal: forwarded to the flash-attention kernel.

    Returns:
        ``[b, s, num_heads * d]``.

    Raises:
        RuntimeError: if no flash-attn varlen kernel could be imported.

    NOTE(review): since q/k/v must share shape and the cuts must end at ``s``, every
    non-negative ID index starts at or beyond the end of ``k`` and is skipped. Only negative
    indices can add tokens. Confirm whether ID tokens were meant to live after the shot
    tokens, which would require ``cuts[-1] < s``.
    """
    assert q.shape == k.shape == v.shape, "shape wrong in attention_per_batch_with_shots"
    b, s_tot, hd = q.shape
    assert hd % num_heads == 0
    device = q.device
    if flash_attn_varlen_func is None:
        raise RuntimeError("flash_attn_varlen_func not available. Please install flash-attn v2+.")

    # Token-major [b, s, n, d] is the layout flash-attn varlen consumes. Because the shots
    # tile [0, s) in order, a sample's packed Q (and K/V without IDs) is simply q[bi].
    q = rearrange(q, "b s (n d) -> b s n d", n=num_heads).contiguous()
    k = rearrange(k, "b s (n d) -> b s n d", n=num_heads).contiguous()
    v = rearrange(v, "b s (n d) -> b s n d", n=num_heads).contiguous()
    tokens_per_id = per_g * 3

    outputs = []
    for bi in range(b):
        cuts = list(shot_latent_indices[bi])
        assert cuts[0] == 0 and cuts[-1] == s_tot, "shot_latent_indices must start with 0 and end with s_tot"
        assert all(start <= end for start, end in zip(cuts[:-1], cuts[1:])), \
            "shot_latent_indices must be non-decreasing"
        bounds = list(zip(cuts[:-1], cuts[1:]))
        q_lengths = [end - start for start, end in bounds]

        ids_for_batch = ID_2_shot[bi] if ID_2_shot is not None and bi < len(ID_2_shot) else None
        if ids_for_batch:
            k_packed, v_packed, kv_lengths = _pack_shots_with_id_tokens(
                k[bi], v[bi], bounds, ids_for_batch, tokens_per_id, id_base=cuts[-1]
            )
        else:
            k_packed, v_packed, kv_lengths = k[bi], v[bi], q_lengths

        out = flash_attn_varlen_func(
            q[bi], k_packed, v_packed,
            _cu_seqlens(q_lengths, device), _cu_seqlens(kv_lengths, device),
            max(q_lengths, default=0), max(kv_lengths, default=0),
            softmax_scale=None, causal=causal
        )
        # The FA3 interface may return (out, softmax_lse), like flash_attn_func does.
        if isinstance(out, tuple):
            out = out[0]
        # Packed output rows are already in sequence order: [s, n, d].
        outputs.append(out)

    x = torch.stack(outputs, dim=0)
    return rearrange(x, "b s n d -> b s (n d)")
|
|
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Adaptive affine modulation: returns ``x * (1 + scale) + shift`` (broadcasting)."""
    gain = 1 + scale
    return x * gain + shift
|
|
|
|
def sinusoidal_embedding_1d(dim, position):
    """Sinusoidal embedding of a 1-D ``position`` tensor.

    The output is ``[cos(p * w_j) ..., sin(p * w_j) ...]`` with ``w_j = 10000 ** (-j / (dim // 2))``
    for ``j < dim // 2``. It is computed in float64 and cast back to ``position.dtype``.
    Shape ``[len(position), 2 * (dim // 2)]``.
    """
    half = dim // 2
    exponents = torch.arange(half, dtype=torch.float64, device=position.device).div(half)
    inv_freq = torch.pow(10000, -exponents)
    angles = torch.outer(position.type(torch.float64), inv_freq)
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
    return embedding.to(position.dtype)
|
|
def precompute_freqs_cis_4d(dim: int, end: int = 1024, theta: float = 10000.0):
    """Four RoPE tables that split ``dim`` across four axes.

    The last three axes (returned as f, h, w) get ``dim // 4`` channels each; the first
    (returned as s; presumably the shot axis) gets the remainder. Each table comes from
    ``precompute_freqs_cis(width, end, theta)``.
    """
    quarter = dim // 4
    widths = (dim - 3 * quarter, quarter, quarter, quarter)
    s_freqs_cis, f_freqs_cis, h_freqs_cis, w_freqs_cis = (
        precompute_freqs_cis(width, end, theta) for width in widths
    )
    return s_freqs_cis, f_freqs_cis, h_freqs_cis, w_freqs_cis
|
|
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    """Three RoPE tables (frame, height, width) that split ``dim`` across three axes.

    Height and width get ``dim // 3`` channels each; the frame axis gets the remainder.
    """
    third = dim // 3
    widths = (dim - 2 * third, third, third)
    f_freqs_cis, h_freqs_cis, w_freqs_cis = (
        precompute_freqs_cis(width, end, theta) for width in widths
    )
    return f_freqs_cis, h_freqs_cis, w_freqs_cis
|
|
|
|
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """Complex RoPE table ``exp(i * pos * theta ** (-2j / dim))``.

    Returns a complex128 tensor of shape ``[end, dim // 2]``: row ``pos`` holds the unit
    rotations for position ``pos`` (``0 <= pos < end``).
    """
    even_channels = torch.arange(0, dim, 2)[: (dim // 2)]
    inv_freq = 1.0 / (theta ** (even_channels.double() / dim))
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)
|
|
|
|
def rope_apply(x, freqs, num_heads):
    """Rotate head-packed features by complex RoPE factors.

    ``x`` is ``[b, s, num_heads * d]``. Consecutive channel pairs of each head are viewed as
    complex numbers (in float64), multiplied by ``freqs`` (broadcast against
    ``[b, s, num_heads, d // 2]``), and flattened back to ``[b, s, num_heads * d]`` in the
    input dtype.
    """
    per_head = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    batch, seq, heads = per_head.shape[0], per_head.shape[1], per_head.shape[2]
    as_pairs = per_head.to(torch.float64).reshape(batch, seq, heads, -1, 2)
    rotated = torch.view_as_complex(as_pairs) * freqs
    return torch.view_as_real(rotated).flatten(2).to(per_head.dtype)
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learnable gain.

    Normalisation runs in float32; the result is cast back to the input dtype before
    being scaled by ``weight``.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalized = self.norm(x.float())
        return normalized.to(x.dtype) * self.weight
|
|
class AttentionModule(nn.Module):
    """Routes q/k/v to the appropriate attention implementation.

    Priority: an explicit ``attn_mask`` goes to ``flash_attention`` with the mask; otherwise
    ``shot_latent_indices`` selects ``attention_per_batch_with_shots``; otherwise plain
    ``flash_attention``.
    """

    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, q, k, v, attn_mask=None, shot_latent_indices = None, per_g=0, ID_2_shot=None):
        heads = self.num_heads
        if attn_mask is None and shot_latent_indices is not None:
            return attention_per_batch_with_shots(
                q=q, k=k, v=v,
                shot_latent_indices=shot_latent_indices,
                num_heads=heads,
                per_g=per_g,
                ID_2_shot=ID_2_shot,
            )
        if attn_mask is not None:
            return flash_attention(q=q, k=k, v=v, num_heads=heads, attn_mask=attn_mask)
        return flash_attention(q=q, k=k, v=v, num_heads=heads)
|
|
class SelfAttention(nn.Module):
    """Self-attention with RMS-normalised q/k and rotary position embedding."""

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Creation order is kept fixed so parameter initialisation consumes the RNG identically.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x, freqs, shot_latent_indices=None, per_g=0, ID_2_shot=None):
        query = self.norm_q(self.q(x))
        key = self.norm_k(self.k(x))
        value = self.v(x)
        heads = self.num_heads
        query, key = rope_apply(query, freqs, heads), rope_apply(key, freqs, heads)
        attended = self.attn(
            query, key, value,
            shot_latent_indices=shot_latent_indices,
            per_g=per_g,
            ID_2_shot=ID_2_shot,
        )
        return self.o(attended)
|
|
|
|
class CrossAttention(nn.Module):
    """Cross-attention from ``x`` to a context ``y``.

    With ``has_image_input``, the first 257 tokens of ``y`` are treated as image tokens and
    attended through separate ``k_img``/``v_img`` projections; that result is added to the
    text-context attention before the output projection.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Creation order is kept fixed so parameter initialisation consumes the RNG identically.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.has_image_input = has_image_input
        if has_image_input:
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x: torch.Tensor, y: torch.Tensor, attn_mask=None):
        if self.has_image_input:
            img, ctx = y[:, :257], y[:, 257:]
        else:
            ctx = y
        query = self.norm_q(self.q(x))
        key = self.norm_k(self.k(ctx))
        value = self.v(ctx)
        out = self.attn(query, key, value, attn_mask=attn_mask)
        if not self.has_image_input:
            return self.o(out)
        img_key = self.norm_k_img(self.k_img(img))
        img_value = self.v_img(img)
        out = out + flash_attention(query, img_key, img_value, num_heads=self.num_heads)
        return self.o(out)
|
|
|
|
class GateModule(nn.Module):
    """Gated residual update: ``x + gate * residual``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        gated = gate * residual
        return x + gated
|
|
class DiTBlock(nn.Module):
    """Transformer block: modulated self-attention, cross-attention, modulated FFN.

    ``t_mod`` carries six modulation vectors (shift/scale/gate for attention and for the
    FFN), added to the learned ``modulation``. A 4-D ``t_mod`` holds them on axis 2
    (per-token modulation); otherwise they are on axis 1.
    """

    def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim

        # Creation order is kept fixed so parameter initialisation consumes the RNG identically.
        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input)
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim),
        )
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

    def forward(self, x, context, t_mod, freqs, attn_mask=None, shot_latent_indices=None, per_g=0, ID_2_shot=None):
        per_token = t_mod.dim() == 4
        combined = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        params = combined.chunk(6, dim=2 if per_token else 1)
        if per_token:
            params = tuple(p.squeeze(2) for p in params)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params

        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        attn_out = self.self_attn(
            attn_in, freqs,
            shot_latent_indices=shot_latent_indices,
            per_g=per_g,
            ID_2_shot=ID_2_shot,
        )
        x = self.gate(x, gate_msa, attn_out)
        x = x + self.cross_attn(self.norm3(x), context, attn_mask)
        ffn_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        return self.gate(x, gate_mlp, self.ffn(ffn_in))
|
|
|
|
class MLP(torch.nn.Module):
    """Projection ``LayerNorm -> Linear -> GELU -> Linear -> LayerNorm``.

    WanModel uses it as ``img_emb`` to project ``clip_feature`` into the model width. With
    ``has_pos_emb``, a learned (zero-initialised) positional embedding over 514 tokens is
    added to the input first.
    """

    def __init__(self, in_dim, out_dim, has_pos_emb=False):
        super().__init__()
        self.proj = torch.nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim)
        )
        self.has_pos_emb = has_pos_emb
        if has_pos_emb:
            # The embedding is added to the in_dim-wide input, so its width must be in_dim.
            # It was previously hard-coded to 1280, which only worked for in_dim == 1280.
            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, in_dim)))

    def forward(self, x):
        if self.has_pos_emb:
            x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
        return self.proj(x)
|
|
|
|
class Head(nn.Module):
    """Output head: modulated LayerNorm followed by a projection to ``out_dim * prod(patch_size)``.

    ``t_mod`` is either per-sample ``[B, dim]`` or per-token ``[B, S, dim]``. It is added to
    the learned ``modulation`` (``[1, 2, dim]``) to form shift and scale.
    """

    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, t_mod):
        modulation = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
        if t_mod.dim() == 3:
            # Per-token: [1, 1, 2, dim] + [B, S, 1, dim] -> [B, S, 2, dim].
            shift, scale = (modulation.unsqueeze(0) + t_mod.unsqueeze(2)).chunk(2, dim=2)
            return self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))
        if t_mod.dim() == 2:
            # [B, dim] -> [B, 1, dim]. Without this, the batch axis would broadcast against
            # the shift/scale axis of modulation ([1, 2, dim]), mixing samples when B == 2
            # and failing when B > 2.
            t_mod = t_mod.unsqueeze(1)
        shift, scale = (modulation + t_mod).chunk(2, dim=1)
        return self.head(self.norm(x) * (1 + scale) + shift)
|
|
|
|
class WanModel(torch.nn.Module):
    """Wan video diffusion transformer.

    ``forward`` works as follows:
    - The timestep goes through a sinusoidal embedding and ``time_embedding``.
      ``time_projection`` then turns it into 6 modulation vectors.
    - The text context is embedded with ``text_embedding``.
    - With ``has_image_input``, ``y`` is concatenated to the latents on the channel axis and
      the ``img_emb(clip_feature)`` tokens are prepended to the context.
    - The latents are patchified into tokens and passed through ``blocks``, ``head`` and
      ``unpatchify``.
    """

    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool,
        has_image_pos_emb: bool = False,
        has_ref_conv: bool = False,
        add_control_adapter: bool = False,
        in_dim_control_adapter: int = 24,
        seperated_timestep: bool = False,
        require_vae_embedding: bool = True,
        require_clip_embedding: bool = True,
        fuse_vae_embedding_in_latents: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.in_dim = in_dim
        self.freq_dim = freq_dim
        self.has_image_input = has_image_input
        self.patch_size = patch_size
        self.seperated_timestep = seperated_timestep
        self.require_vae_embedding = require_vae_embedding
        self.require_clip_embedding = require_clip_embedding
        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents

        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim)
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        self.time_projection = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6))
        self.blocks = nn.ModuleList([
            DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
            for _ in range(num_layers)
        ])
        self.head = Head(dim, out_dim, patch_size, eps)
        head_dim = dim // num_heads
        self.freqs = precompute_freqs_cis_3d(head_dim)
        # Not read in this class's forward; presumably consumed by shot-aware callers elsewhere.
        self.shot_freqs = precompute_freqs_cis_4d(head_dim)

        if has_image_input:
            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)
        if has_ref_conv:
            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
        self.has_image_pos_emb = has_image_pos_emb
        self.has_ref_conv = has_ref_conv
        if add_control_adapter:
            self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
        else:
            self.control_adapter = None

    def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
        """Embed latents ``[b, c, f, h, w]`` into tokens ``[b, f*h*w, dim]``.

        If a control adapter exists and camera latents are given, the adapter output is
        added to the patch embeddings. Returns ``(tokens, grid_size)``, where ``grid_size``
        is the ``(f, h, w)`` token grid.
        """
        x = self.patch_embedding(x)
        if self.control_adapter is not None and control_camera_latents_input is not None:
            y_camera = self.control_adapter(control_camera_latents_input)
            # Batched add. The previous `zip` + `x[0].unsqueeze(0)` silently kept only the
            # first sample of the batch.
            x = x + y_camera
        grid_size = x.shape[2:]
        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
        return x, grid_size

    def unpatchify(self, x: torch.Tensor, grid_size: Sequence[int]):
        """Inverse of ``patchify`` for head outputs: ``[b, f*h*w, p0*p1*p2*c] -> [b, c, F, H, W]``."""
        return rearrange(
            x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
            f=grid_size[0], h=grid_size[1], w=grid_size[2],
            x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
        )

    def forward(self,
                x: torch.Tensor,
                timestep: torch.Tensor,
                context: torch.Tensor,
                clip_feature: Optional[torch.Tensor] = None,
                y: Optional[torch.Tensor] = None,
                use_gradient_checkpointing: bool = False,
                use_gradient_checkpointing_offload: bool = False,
                **kwargs,
                ):
        """Denoise one step. Extra ``kwargs`` are accepted and ignored."""
        t = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, timestep))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)

        if self.has_image_input:
            x = torch.cat([x, y], dim=1)
            clip_embedding = self.img_emb(clip_feature)
            context = torch.cat([clip_embedding, context], dim=1)

        x, (f, h, w) = self.patchify(x)

        # Per-token RoPE factors: frame/height/width tables broadcast over the (f, h, w) grid
        # and concatenated on the channel axis -> [f*h*w, 1, C].
        freqs = torch.cat([
            self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                if use_gradient_checkpointing_offload:
                    with torch.autograd.graph.save_on_cpu():
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x, context, t_mod, freqs,
                            use_reentrant=False,
                        )
                else:
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            else:
                x = block(x, context, t_mod, freqs)

        x = self.head(x, t)
        x = self.unpatchify(x, (f, h, w))
        return x

    @staticmethod
    def state_dict_converter():
        return WanModelStateDictConverter()
| |
| |
class WanModelStateDictConverter:
    """Convert Wan DiT checkpoints (Diffusers or Civitai layout) into this module's keys + config."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Rename Diffusers keys to WanModel keys.

        Per-block keys are matched through their ``blocks.0`` template and re-indexed with the
        original block number. Keys matching no rule are dropped. The returned config is only
        filled for one known checkpoint hash.

        NOTE(review): that config uses keys (``model_type``, ``text_len``, ``window_size``,
        ``qk_norm``, ``cross_attn_norm``) that WanModel.__init__ does not accept, and it lacks
        the required ``has_image_input``. Confirm how callers consume it.
        """
        rename_dict = {
            "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
            "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
            "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
            "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
            "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
            "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
            "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
            "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
            "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
            "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
            "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
            "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
            "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
            "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
            "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
            "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
            "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
            "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
            "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
            "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
            "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
            "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
            "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
            "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
            "blocks.0.norm2.bias": "blocks.0.norm3.bias",
            "blocks.0.norm2.weight": "blocks.0.norm3.weight",
            "blocks.0.scale_shift_table": "blocks.0.modulation",
            "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
            "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
            "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
            "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
            "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
            "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
            "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
            "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
            "condition_embedder.time_proj.bias": "time_projection.1.bias",
            "condition_embedder.time_proj.weight": "time_projection.1.weight",
            "patch_embedding.bias": "patch_embedding.bias",
            "patch_embedding.weight": "patch_embedding.weight",
            "scale_shift_table": "head.modulation",
            "proj_out.bias": "head.head.bias",
            "proj_out.weight": "head.head.weight",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
                continue
            parts = name.split(".")
            # Map e.g. "blocks.7.attn1.to_q.weight" onto its "blocks.0..." template.
            template = ".".join(parts[:1] + ["0"] + parts[2:])
            if template in rename_dict:
                renamed = rename_dict[template].split(".")
                # Put the original block index back.
                state_dict_[".".join(renamed[:1] + [parts[1]] + renamed[2:])] = param
        if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
            config = {
                "model_type": "t2v",
                "patch_size": (1, 2, 2),
                "text_len": 512,
                "in_dim": 16,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "window_size": (-1, -1),
                "qk_norm": True,
                "cross_attn_norm": True,
                "eps": 1e-6,
            }
        else:
            config = {}
        return state_dict_, config

    @staticmethod
    def _civitai_configs():
        """Known Civitai checkpoint key-hashes -> WanModel constructor kwargs.

        A fresh dict is built on every call so callers may mutate what they receive.
        """
        def cfg(has_image_input, in_dim, dim, ffn_dim, num_heads, num_layers, out_dim=16, **extra):
            # Key order matches the former literal dicts; `extra` entries come last.
            return {
                "has_image_input": has_image_input,
                "patch_size": [1, 2, 2],
                "in_dim": in_dim,
                "dim": dim,
                "ffn_dim": ffn_dim,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": out_dim,
                "num_heads": num_heads,
                "num_layers": num_layers,
                "eps": 1e-6,
                **extra,
            }

        small = dict(dim=1536, ffn_dim=8960, num_heads=12, num_layers=30)
        large = dict(dim=5120, ffn_dim=13824, num_heads=40, num_layers=40)
        camera = dict(has_ref_conv=False, add_control_adapter=True, in_dim_control_adapter=24)
        return {
            "9269f8db9040a9d860eaca435be61814": cfg(False, 16, **small),
            "aafcfd9672c3a2456dc46e1cb6e52c70": cfg(False, 16, **large),
            "6bfcfb3b342cb286ce886889d519a77e": cfg(True, 36, **large),
            "6d6ccde6845b95ad9114ab993d917893": cfg(True, 36, **small),
            "349723183fc063b2bfc10bb2835cf677": cfg(True, 48, **small),
            "efa44cddf936c70abd0ea28b6cbe946c": cfg(True, 48, **large),
            "3ef3b1f8e1dab83d5b71fd7b617f859f": cfg(True, 36, **large, has_image_pos_emb=True),
            "70ddad9d3a133785da5ea371aae09504": cfg(True, 48, **small, has_ref_conv=True),
            "26bde73488a92e64cc20b0a7485b9e5b": cfg(True, 48, **large, has_ref_conv=True),
            "ac6a5aa74f4a0aab6f64eb9a72f19901": cfg(True, 32, **small, **camera),
            "b61c605c2adbd23124d152ed28e049ae": cfg(True, 32, **large, **camera),
            "1f5ab7703c6fc803fdded85ff040c316": cfg(
                False, 48, dim=3072, ffn_dim=14336, num_heads=24, num_layers=30, out_dim=48,
                seperated_timestep=True,
                require_clip_embedding=False,
                require_vae_embedding=False,
                fuse_vae_embedding_in_latents=True,
            ),
            "5b013604280dd715f8457c6ed6d6a626": cfg(False, 36, **large, require_clip_embedding=False),
            "2267d489f0ceb9f21836532952852ee5": cfg(False, 52, **large, has_ref_conv=True, require_clip_embedding=False),
            "47dbeab5e560db3180adf51dc0232fb1": cfg(False, 36, **large, **camera, require_clip_embedding=False),
        }

    def from_civitai(self, state_dict):
        """Drop ``vace*`` keys and look up the model config by the remaining keys' hash.

        Returns ``(state_dict, config)``; ``config`` is ``{}`` for unknown checkpoints.
        """
        state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
        # Hash once. The former if/elif chain re-hashed the key set for every candidate and
        # contained an unreachable duplicate branch for "6bfcfb3b...".
        config = self._civitai_configs().get(hash_state_dict_keys(state_dict), {})
        return state_dict, config
|
|