# lingbot-va / wan_va/modules/model.py
# bazaar-research's picture
# Upload folder using huggingface_hub
# 0a7036f verified
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
import math
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import (
PixArtAlphaTextProjection,
TimestepEmbedding,
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import FP32LayerNorm
from einops import rearrange
# Resolve ``flash_attn_func``: try ``flash_attn_interface`` first, then the
# ``flash_attn`` package. Only ImportError is handled, so unrelated failures
# (and KeyboardInterrupt/SystemExit) are no longer swallowed. When neither
# package can be imported, the module stays importable (attn_mode='torch' does
# not need flash attention) and the error is raised only if the flash kernel is
# actually called.
try:
    from flash_attn_interface import flash_attn_func
except ImportError:
    try:
        from flash_attn import flash_attn_func
    except ImportError as _exc:
        # ``_exc`` is unbound once the except block ends; keep a reference so
        # the stub can chain the original cause.
        _FLASH_ATTN_IMPORT_ERROR = _exc

        def flash_attn_func(*args, **kwargs):
            """Placeholder used when no flash-attention package is installed."""
            raise ImportError(
                "attn_mode='flashattn' requires `flash_attn_interface` or "
                "`flash_attn` to be installed; use attn_mode='torch' otherwise."
            ) from _FLASH_ATTN_IMPORT_ERROR
__all__ = ['WanTransformer3DModel']
def custom_sdpa(q, k, v):
    """Run ``F.scaled_dot_product_attention`` on tensors whose axes 1 and 2 are swapped.

    Axes 1 and 2 of every input are transposed before the call and the result
    is transposed back, so the output has the same axis order as ``q``.
    Callers in this file pass tensors produced by ``unflatten(2, (heads, -1))``,
    i.e. (batch, tokens, heads, head_dim).

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.

    Returns:
        Attention output with axes ordered like ``q``.
    """
    q_t, k_t, v_t = (tensor.transpose(1, 2) for tensor in (q, k, v))
    attended = F.scaled_dot_product_attention(q_t, k_t, v_t)
    return attended.transpose(1, 2)
class WanTimeTextImageEmbedding(nn.Module):
    """Timestep embedding stack plus a text-projection submodule.

    ``forward`` only embeds timesteps; ``text_embedder`` is constructed here
    and is meant to be called directly by the owner of this module.
    """

    def __init__(
        self,
        dim,
        time_freq_dim,
        time_proj_dim,
        text_embed_dim,
        pos_embed_seq_len,
    ):
        """Build the timestep and text embedders.

        Args:
            dim: Output width of the timestep embedding and text projection.
            time_freq_dim: Number of channels produced by the ``Timesteps``
                projection.
            time_proj_dim: Output width of ``time_proj``.
            text_embed_dim: Input width of the text projection.
            pos_embed_seq_len: Accepted for interface compatibility; not used
                by this class.
        """
        super().__init__()
        self.timesteps_proj = Timesteps(
            num_channels=time_freq_dim,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        self.time_embedder = TimestepEmbedding(
            in_channels=time_freq_dim,
            time_embed_dim=dim,
        )
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(
            text_embed_dim,
            dim,
            act_fn="gelu_tanh",
        )

    def forward(
        self,
        timestep: torch.Tensor,
        dtype=None,
    ):
        """Embed a 2-D grid of timesteps.

        Args:
            timestep: 2-D tensor; its two axes are restored on the outputs.
            dtype: dtype the timestep embedding is cast to before ``time_proj``.

        Returns:
            Tuple ``(temb, timestep_proj)``, each reshaped to
            ``(timestep.shape[0], timestep.shape[1], -1)``.
        """
        rows, cols = timestep.shape
        sinusoid = self.timesteps_proj(timestep.reshape(-1))

        # Match the embedder's weight dtype, unless the weights are int8, in
        # which case the projection is left in its own dtype.
        weight_dtype = self.time_embedder.linear_1.weight.dtype
        if sinusoid.dtype != weight_dtype and weight_dtype != torch.int8:
            sinusoid = sinusoid.to(weight_dtype)

        temb = self.time_embedder(sinusoid).to(dtype=dtype)
        projected = self.time_proj(self.act_fn(temb))
        return temb.reshape(rows, cols, -1), projected.reshape(rows, cols, -1)
class WanRotaryPosEmbed(nn.Module):
    """Rotary position embedding over three grid axes (f, h, w).

    The head dimension is split into three parts: ``h_dim`` and ``w_dim`` each
    get ``attention_head_dim // 3`` and ``f_dim`` gets the remainder. Each part
    has its own base-frequency vector, stored as a non-persistent float64 buffer.
    """

    def __init__(
        self,
        attention_head_dim,
        patch_size,
        max_seq_len,
        theta=10000.0,
    ):
        """Store the configuration and register the per-axis base frequencies.

        Args:
            attention_head_dim: Per-head channel count to split across axes.
            patch_size: Stored on the instance; not used by the computation here.
            max_seq_len: Stored on the instance; not used by the computation here.
            theta: Base of the geometric frequency progression.
        """
        super().__init__()
        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len
        self.theta = theta

        third = self.attention_head_dim // 3
        self.f_dim = self.attention_head_dim - 2 * third
        self.h_dim = third
        self.w_dim = third

        # Non-persistent: the buffers are rebuilt from the config and are not
        # part of the state dict.
        buffer_names = ("f_freqs_base", "h_freqs_base", "w_freqs_base")
        for buffer_name, base in zip(buffer_names,
                                     self._precompute_freqs_base()):
            self.register_buffer(buffer_name, base, persistent=False)

    def _axis_freqs_base(self, axis_dim):
        """Return ``1 / theta ** (2k / axis_dim)`` for ``k < axis_dim // 2`` in float64."""
        exponents = torch.arange(
            0, axis_dim, 2)[:(axis_dim // 2)].double() / axis_dim
        return 1.0 / (self.theta**exponents)

    def _precompute_freqs_base(self):
        """Return the (f, h, w) base-frequency vectors."""
        return (
            self._axis_freqs_base(self.f_dim),
            self._axis_freqs_base(self.h_dim),
            self._axis_freqs_base(self.w_dim),
        )

    def forward(self, grid_ids):
        """Turn grid coordinates into unit-magnitude complex rotation factors.

        Args:
            grid_ids: Tensor indexed as ``[:, axis, :]`` with axis 0/1/2 giving
                the f/h/w coordinate of every token.

        Returns:
            Complex tensor whose last axis concatenates the f, h and w
            rotations; computed under ``torch.no_grad()``.
        """
        axis_bases = (self.f_freqs_base, self.h_freqs_base, self.w_freqs_base)
        with torch.no_grad():
            per_axis = [
                grid_ids[:, axis, :].unsqueeze(-1) * base
                for axis, base in enumerate(axis_bases)
            ]
            angles = torch.cat(per_axis, dim=-1).float()
            return torch.polar(torch.ones_like(angles), angles)
class WanAttention(torch.nn.Module):
    """Multi-head attention with RMS-normalised q/k and an optional slot-based KV cache.

    Queries and keys are RMS-normalised over the full projected width before
    being split into heads. When ``cross_attention_dim_head`` is None the
    instance owns a dict of named key/value caches (``attn_caches``); otherwise
    caching is disabled (``attn_caches`` is None) and every cache method is a
    no-op.

    Each cache is a dict with:
        * ``k`` / ``v``: pre-allocated pools of shape
          (batch, total_tokens, heads, head_dim);
        * ``mask``: bool per slot, True when the slot holds live data;
        * ``id``: integer per slot, the insertion counter of the write that
          filled it (-1 when never written / evicted);
        * ``is_pred``: bool per slot, True when the write was flagged as a
          prediction (``update_cache == 1``).
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        eps=1e-5,
        dropout=0.0,
        cross_attention_dim_head=None,
        attn_mode='torch',
    ):
        """Build projections, norms and (for self-attention) the cache registry.

        Args:
            dim: Input/output channel width.
            heads: Number of attention heads.
            dim_head: Channels per head for the query projection.
            eps: Epsilon of the RMS norms.
            dropout: Dropout probability applied after the output projection.
            cross_attention_dim_head: Channels per head for key/value
                projections; None selects self-attention with caching.
            attn_mode: ``'torch'`` (``custom_sdpa``) or ``'flashattn'``
                (``flash_attn_func``).

        Raises:
            ValueError: If ``attn_mode`` is not one of the supported values.
        """
        super().__init__()
        if attn_mode == 'torch':
            self.attn_op = custom_sdpa
        elif attn_mode == 'flashattn':
            self.attn_op = flash_attn_func
        else:
            raise ValueError(
                f"Unsupported attention mode: {attn_mode}, only support torch and flashattn"
            )
        self.inner_dim = dim_head * heads
        self.heads = heads
        self.cross_attention_dim_head = cross_attention_dim_head
        self.kv_inner_dim = (self.inner_dim
                             if cross_attention_dim_head is None else
                             cross_attention_dim_head * heads)
        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_out = torch.nn.ModuleList([
            torch.nn.Linear(self.inner_dim, dim, bias=True),
            torch.nn.Dropout(dropout),
        ])
        self.norm_q = torch.nn.RMSNorm(dim_head * heads,
                                       eps=eps,
                                       elementwise_affine=True)
        self.norm_k = torch.nn.RMSNorm(dim_head * heads,
                                       eps=eps,
                                       elementwise_affine=True)
        # Only self-attention keeps KV caches.
        self.attn_caches = {} if cross_attention_dim_head is None else None

    def _get_cache(self, cache_name):
        """Return the cache dict for ``cache_name``, or None.

        None is returned when caching is disabled, when no cache was ever
        registered under that name, or when it was cleared.
        """
        if self.attn_caches is None:
            return None
        return self.attn_caches.get(cache_name)

    def clear_pred_cache(self, cache_name):
        """Invalidate every slot that was written with ``is_pred=True``.

        Does nothing when the named cache does not exist or was cleared.
        """
        cache = self._get_cache(cache_name)
        if cache is None:
            return
        cache['mask'][cache['is_pred']] = False

    def clear_cache(self, cache_name):
        """Drop the named cache entirely (the entry is set to None)."""
        if self.attn_caches is None:
            return
        self.attn_caches[cache_name] = None

    def init_kv_cache(self, cache_name, total_tolen, num_head, head_dim,
                      device, dtype, batch_size):
        """Allocate an empty cache with ``total_tolen`` token slots.

        Args:
            cache_name: Key under which the cache is stored.
            total_tolen: Number of token slots in the pool.
            num_head: Number of heads stored per token.
            head_dim: Channels per head.
            device: Device for all cache tensors.
            dtype: dtype of the key/value pools.
            batch_size: Leading dimension of the key/value pools.
        """
        if self.attn_caches is None:
            return
        pool_shape = [batch_size, total_tolen, num_head, head_dim]
        self.attn_caches[cache_name] = {
            'k': torch.empty(pool_shape, device=device, dtype=dtype),
            'v': torch.empty(pool_shape, device=device, dtype=dtype),
            'id': torch.full((total_tolen, ), -1, device=device),
            "mask": torch.zeros((total_tolen, ),
                                dtype=torch.bool,
                                device=device),
            "is_pred": torch.zeros((total_tolen, ),
                                   dtype=torch.bool,
                                   device=device),
        }

    def allocate_slots(self, cache_name, key_size):
        """Return ``key_size`` free slot indices, evicting old entries if needed.

        When fewer than ``key_size`` slots are free, the occupied slots with
        the smallest ``id`` values are released until enough are available.
        """
        cache = self.attn_caches[cache_name]
        mask = cache["mask"]
        ids = cache["id"]
        free = (~mask).nonzero(as_tuple=False).squeeze(-1)
        if free.numel() < key_size:
            used = mask.nonzero(as_tuple=False).squeeze(-1)
            oldest_first = torch.argsort(ids[used])
            shortfall = key_size - free.numel()
            evicted = used[oldest_first[:shortfall]]
            mask[evicted] = False
            ids[evicted] = -1
            free = (~mask).nonzero(as_tuple=False).squeeze(-1)
        assert free.numel() >= key_size, (
            f"KV cache '{cache_name}' has {free.numel()} slots but "
            f"{key_size} were requested")
        return free[:key_size]

    def _next_cache_id(self, cache_name):
        """Return one more than the largest live id, or 0 for an empty cache."""
        cache = self.attn_caches[cache_name]
        ids = cache['id']
        mask = cache['mask']
        if mask.any():
            return ids[mask].max() + 1
        return torch.tensor(0, device=ids.device, dtype=ids.dtype)

    def update_cache(self, cache_name, key, value, is_pred):
        """Write ``key``/``value`` into free slots and mark them live.

        Args:
            cache_name: Cache to write to.
            key: Key tensor; ``key.shape[1]`` slots are consumed.
            value: Value tensor matching ``key``.
            is_pred: Flag stored per slot for ``clear_pred_cache``.

        Returns:
            The slot indices that were written.
        """
        cache = self.attn_caches[cache_name]
        slots = self.allocate_slots(cache_name, key.shape[1])
        # Computed after allocation so evicted entries do not count.
        new_id = self._next_cache_id(cache_name)
        cache['k'][:, slots] = key
        cache['v'][:, slots] = value
        cache['mask'][slots] = True
        cache['id'][slots] = new_id
        cache['is_pred'][slots] = is_pred
        return slots

    def restore_cache(self, cache_name, slots):
        """Mark ``slots`` as free again (undo a temporary ``update_cache``)."""
        self.attn_caches[cache_name]['mask'][slots] = False

    @staticmethod
    def _apply_rotary_emb(x, freqs):
        """Rotate consecutive channel pairs of ``x`` by the complex ``freqs``.

        The last axis is viewed as (pairs, 2) complex numbers in float64,
        multiplied by ``freqs``, and converted back to ``x``'s dtype.
        """
        pairs = x.to(torch.float64).reshape(x.shape[0], x.shape[1],
                                            x.shape[2], -1, 2)
        rotated = torch.view_as_real(torch.view_as_complex(pairs) *
                                     freqs).flatten(3)
        return rotated.to(x.dtype)

    def forward(
        self,
        q,
        k,
        v,
        rotary_emb,
        update_cache=0,
        cache_name='pos',
    ):
        """Project, normalise, optionally rotate and cache, then attend.

        Args:
            q: Input to the query projection.
            k: Input to the key projection.
            v: Input to the value projection.
            rotary_emb: Complex rotary factors applied to query and key, or
                None to skip rotation.
            update_cache: 0 makes the new keys/values visible for this call
                only (their slots are released afterwards); any other value
                keeps them in the cache, and 1 additionally flags them as
                predictions.
            cache_name: Name of the cache to use. If no cache is registered
                under this name (or it was cleared), attention runs without
                caching instead of raising ``KeyError``.

        Returns:
            Output of the attention followed by the output projection and
            dropout.
        """
        kv_cache = self._get_cache(cache_name)
        use_cache = kv_cache is not None and kv_cache['k'] is not None

        query = self.norm_q(self.to_q(q)).unflatten(2, (self.heads, -1))
        key = self.norm_k(self.to_k(k)).unflatten(2, (self.heads, -1))
        value = self.to_v(v).unflatten(2, (self.heads, -1))

        if rotary_emb is not None:
            query = self._apply_rotary_emb(query, rotary_emb)
            key = self._apply_rotary_emb(key, rotary_emb)

        slots = None
        if use_cache:
            # NOTE: when the pool is full, this write evicts the oldest
            # entries even if update_cache == 0; only the new slots are
            # released afterwards.
            slots = self.update_cache(cache_name,
                                      key,
                                      value,
                                      is_pred=(update_cache == 1))
            valid = kv_cache['mask'].nonzero(as_tuple=False).squeeze(-1)
            key = kv_cache['k'][:, valid]
            value = kv_cache['v'][:, valid]

        hidden_states = self.attn_op(query, key, value)

        if use_cache and update_cache == 0:
            self.restore_cache(cache_name, slots)

        hidden_states = hidden_states.flatten(2, 3).type_as(query)
        hidden_states = self.to_out[0](hidden_states)
        return self.to_out[1](hidden_states)
class WanTransformerBlock(nn.Module):
    """Transformer block: modulated self-attention, cross-attention, modulated FFN.

    Shift/scale/gate tensors for the self-attention and feed-forward branches
    come from a learned ``scale_shift_table`` added to the timestep projection.
    The cross-attention branch is added to the residual without a gate.
    """

    def __init__(
        self,
        dim,
        ffn_dim,
        num_heads,
        cross_attn_norm=False,
        eps=1e-6,
        attn_mode: str = "flashattn",
    ):
        """Build the three sub-layers.

        Args:
            dim: Channel width of the hidden states.
            ffn_dim: Inner width of the feed-forward network.
            num_heads: Number of attention heads (head width is
                ``dim // num_heads``).
            cross_attn_norm: If True, apply an affine FP32 layer norm before
                cross-attention; otherwise use identity.
            eps: Epsilon for the norms.
            attn_mode: Attention backend forwarded to ``WanAttention``.
        """
        super().__init__()
        self.attn_mode = attn_mode
        head_dim = dim // num_heads

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=head_dim,
            eps=eps,
            cross_attention_dim_head=None,
            attn_mode=attn_mode,
        )

        # 2. Cross-attention
        self.attn2 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=head_dim,
            eps=eps,
            cross_attention_dim_head=head_dim,
            attn_mode=attn_mode,
        )
        if cross_attn_norm:
            self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True)
        else:
            self.norm2 = nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim,
                               inner_dim=ffn_dim,
                               activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        self.scale_shift_table = nn.Parameter(
            torch.randn(1, 6, dim) / dim**0.5)

    @staticmethod
    def _modulated_norm(norm, states, scale, shift):
        """Normalise ``states`` in float32, apply scale/shift, restore dtype."""
        return (norm(states.float()) * (1. + scale) +
                shift).type_as(states)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        temb,
        rotary_emb,
        update_cache=0,
        cache_name='pos',
    ) -> torch.Tensor:
        """Apply the block.

        Args:
            hidden_states: Token states entering the block.
            encoder_hidden_states: Context attended to by the cross-attention.
            temb: Timestep projection with six modulation vectors on axis 2
                (batch, tokens, 6, channels).
            rotary_emb: Rotary factors passed to the self-attention.
            update_cache: Cache policy forwarded to the self-attention.
            cache_name: Cache name forwarded to both attention layers.

        Returns:
            Updated hidden states in the dtype of the input ``hidden_states``.
        """
        modulation = self.scale_shift_table[None] + temb.float()
        (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa,
         c_gate_msa) = modulation.unbind(dim=2)

        # 1. Self-attention
        attn_in = self._modulated_norm(self.norm1, hidden_states, scale_msa,
                                       shift_msa)
        self_attn_out = self.attn1(attn_in,
                                   attn_in,
                                   attn_in,
                                   rotary_emb,
                                   update_cache=update_cache,
                                   cache_name=cache_name)
        hidden_states = (hidden_states.float() +
                         self_attn_out * gate_msa).type_as(hidden_states)

        # 2. Cross-attention (never writes to a cache)
        cross_in = self.norm2(hidden_states.float()).type_as(hidden_states)
        cross_attn_out = self.attn2(cross_in,
                                    encoder_hidden_states,
                                    encoder_hidden_states,
                                    None,
                                    update_cache=0,
                                    cache_name=cache_name)
        hidden_states = hidden_states + cross_attn_out

        # 3. Feed-forward
        ffn_in = self._modulated_norm(self.norm3, hidden_states, c_scale_msa,
                                      c_shift_msa)
        ff_output = self.ffn(ffn_in)
        return (hidden_states.float() +
                ff_output.float() * c_gate_msa).type_as(hidden_states)
class WanTransformer3DModel(ModelMixin, ConfigMixin):
    r"""
    Transformer over patchified latents or action tokens.

    The same stack of `WanTransformerBlock`s is used for both modes; the input
    embedding, timestep embedder and output projection are mode specific
    (selected by `action_mode` in `forward`).
    """

    @register_to_config
    def __init__(self,
                 patch_size=[1, 2, 2],
                 num_attention_heads=24,
                 attention_head_dim=128,
                 in_channels=48,
                 out_channels=48,
                 action_dim=30,
                 text_dim=4096,
                 freq_dim=256,
                 ffn_dim=14336,
                 num_layers=30,
                 cross_attn_norm=True,
                 eps=1e-06,
                 rope_max_seq_len=1024,
                 pos_embed_seq_len=None,
                 attn_mode="torch"):
        r"""
        Args:
            patch_size: Patch extent along (frame, height, width).
            num_attention_heads: Number of attention heads.
            attention_head_dim: Channels per head; the model width is
                `num_attention_heads * attention_head_dim`.
            in_channels: Latent channels before patchification.
            out_channels: Latent channels produced per patch element.
            action_dim: Channel width of action tokens (input and output).
            text_dim: Width of the incoming text embeddings.
            freq_dim: Channel count of the sinusoidal timestep projection.
            ffn_dim: Inner width of each block's feed-forward network.
            num_layers: Number of transformer blocks.
            cross_attn_norm: Whether blocks normalise before cross-attention.
            eps: Epsilon for the norms.
            rope_max_seq_len: Forwarded to `WanRotaryPosEmbed`.
            pos_embed_seq_len: Forwarded to `WanTimeTextImageEmbedding`.
            attn_mode: Attention backend, `"torch"` or `"flashattn"`.
        """
        super().__init__()
        self.patch_size = patch_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        patch_volume_in = (in_channels * patch_size[0] * patch_size[1] *
                           patch_size[2])

        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size,
                                      rope_max_seq_len)
        self.patch_embedding_mlp = nn.Linear(patch_volume_in, inner_dim)
        self.action_embedder = nn.Linear(action_dim, inner_dim)
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )
        # Action mode gets its own copy of the timestep embedder.
        self.condition_embedder_action = deepcopy(self.condition_embedder)

        block_list = []
        for _ in range(num_layers):
            block_list.append(
                WanTransformerBlock(inner_dim,
                                    ffn_dim,
                                    num_attention_heads,
                                    cross_attn_norm,
                                    eps,
                                    attn_mode=attn_mode))
        self.blocks = nn.ModuleList(block_list)

        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim,
                                  out_channels * math.prod(patch_size))
        self.action_proj_out = nn.Linear(inner_dim, action_dim)
        self.scale_shift_table = nn.Parameter(
            torch.randn(1, 2, inner_dim) / inner_dim**0.5)

    def clear_cache(self, cache_name):
        """Drop the named self-attention KV cache in every block."""
        for layer in self.blocks:
            layer.attn1.clear_cache(cache_name)

    def clear_pred_cache(self, cache_name):
        """Invalidate prediction-flagged cache slots in every block."""
        for layer in self.blocks:
            layer.attn1.clear_pred_cache(cache_name)

    def create_empty_cache(self, cache_name, attn_window,
                           latent_token_per_chunk, action_token_per_chunk,
                           device, dtype, batch_size):
        """Allocate an empty self-attention KV cache in every block.

        The pool holds ``attn_window // 2`` chunks of latent tokens plus
        ``attn_window // 2`` chunks of action tokens.
        """
        half_window = attn_window // 2
        total_tokens = (half_window * latent_token_per_chunk +
                        half_window * action_token_per_chunk)
        for layer in self.blocks:
            layer.attn1.init_kv_cache(cache_name, total_tokens,
                                      self.num_attention_heads,
                                      self.attention_head_dim, device, dtype,
                                      batch_size)

    def forward(
        self,
        input_dict,
        update_cache=0,
        cache_name="pos",
        action_mode=False,
    ):
        r"""
        Forward pass through the diffusion model.

        Args:
            input_dict (`dict`):
                Must contain:
                `'noisy_latents'`: 5-D tensor laid out as (b, c, f, h, w);
                `'text_emb'`: text embeddings fed to the text projection;
                `'grid_id'`: grid coordinates passed to the rotary embedding;
                `'timesteps'`: 2-D timesteps, repeated along dim 1 once per
                token of a frame (presumably one timestep per frame; verify
                against the caller).
            update_cache (`int`):
                Cache policy forwarded to every block's self-attention.
            cache_name (`str`):
                Name of the KV cache to use.
            action_mode (`bool`):
                If True, treat the latents as action tokens (no patchification,
                action embedder / timestep embedder / output head); otherwise
                use the latent path.

        Returns:
            Tensor:
                Action mode: the `action_proj_out` projection of every token.
                Latent mode: the `proj_out` projection with each token
                expanded into `prod(patch_size)` rows.
        """
        noisy = input_dict['noisy_latents']

        if action_mode:
            # Action tokens: one token per (f, h, w) position, no patching.
            tokens = self.action_embedder(
                rearrange(noisy, 'b c f h w -> b (f h w) c'))  # B L1 C
            stride_h, stride_w = 1, 1
            time_embedder = self.condition_embedder_action
        else:
            # Latents: fold each patch into the channel axis, then embed.
            p_f, p_h, p_w = (self.patch_size[0], self.patch_size[1],
                             self.patch_size[2])
            tokens = self.patch_embedding_mlp(
                rearrange(
                    noisy,
                    'b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)',
                    p1=p_f,
                    p2=p_h,
                    p3=p_w))
            stride_h, stride_w = p_h, p_w
            time_embedder = self.condition_embedder

        context = self.condition_embedder.text_embedder(
            input_dict["text_emb"])  # B L2 C
        rotary_emb = self.rope(input_dict['grid_id'])[:, :, None]  # 1 L 1 C

        # Expand the timesteps so every token of a frame shares its timestep.
        tokens_per_frame = ((noisy.shape[-2] // stride_h) *
                            (noisy.shape[-1] // stride_w))
        token_timesteps = torch.repeat_interleave(input_dict['timesteps'],
                                                  tokens_per_frame,
                                                  dim=1)
        temb, timestep_proj = time_embedder(token_timesteps,
                                            dtype=tokens.dtype)
        timestep_proj = timestep_proj.unflatten(2, (6, -1))  # B L 6 C

        for layer in self.blocks:
            tokens = layer(tokens,
                           context,
                           timestep_proj,
                           rotary_emb,
                           update_cache=update_cache,
                           cache_name=cache_name)

        # Final modulated norm: shift/scale from the table plus the timestep
        # embedding.
        out_modulation = self.scale_shift_table[None] + temb[:, :, None, ...]
        shift, scale = out_modulation.unbind(dim=2)
        shift = shift.to(tokens.device)
        scale = scale.to(tokens.device)
        tokens = (self.norm_out(tokens.float()) * (1. + scale) +
                  shift).type_as(tokens)

        if action_mode:
            return self.action_proj_out(tokens)

        projected = self.proj_out(tokens)
        return rearrange(projected,
                         'b l (n c) -> b (l n) c',
                         n=math.prod(self.patch_size))
if __name__ == '__main__':
    # Smoke check: build the model with an explicit configuration (the values
    # below equal the constructor defaults) and print its module tree.
    demo_config = {
        'patch_size': [1, 2, 2],
        'num_attention_heads': 24,
        'attention_head_dim': 128,
        'in_channels': 48,
        'out_channels': 48,
        'action_dim': 30,
        'text_dim': 4096,
        'freq_dim': 256,
        'ffn_dim': 14336,
        'num_layers': 30,
        'cross_attn_norm': True,
        'eps': 1e-6,
        'rope_max_seq_len': 1024,
        'pos_embed_seq_len': None,
        'attn_mode': "torch",
    }
    model = WanTransformer3DModel(**demo_config)
    print(model)