# stevengrove's picture
# Initial commit with Xet-tracked image assets
# fcfea15
from typing import Any, List, Tuple, Optional, Union, Dict
from einops import rearrange
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
from modules.models.attention import attention, get_cu_seqlens, get_preferred_attention_backend
from .posemb_layers import apply_rotary_emb, get_nd_rotary_pos_embed
from .modulate_layers import load_modulation, modulate, apply_gate
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension of the input."""

    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Build the RMSNorm layer.

        Args:
            dim (int): Size of the last dimension of the input tensor.
            elementwise_affine (bool, optional): When true, a learnable ``weight`` of
                shape ``(dim,)`` initialised to ones is created. Default is True.
            eps (float, optional): Value added to the mean square before the inverse
                square root is taken. Default is 1e-6.
            device, dtype: Forwarded to the creation of ``weight``.

        Attributes:
            eps (float): The stabilising constant described above.
            weight (nn.Parameter): Learnable scale; only present when
                ``elementwise_affine`` is true.
        """
        super().__init__()
        self.eps = eps
        if not elementwise_affine:
            # No ``weight`` attribute is created at all; forward() tests for its presence.
            return
        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

    def _norm(self, x):
        """
        Scale ``x`` by the reciprocal RMS of its last dimension.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: ``x * rsqrt(mean(x ** 2, dim=-1) + eps)``.
        """
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        """
        Normalize ``x`` and, when present, apply the learnable scale.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized (and optionally scaled) tensor.
        """
        # The statistics are computed in float32, then cast back to the input dtype.
        normed = self._norm(x.float()).type_as(x)
        if not hasattr(self, "weight"):
            return normed
        return normed * self.weight
class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for
    text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
    (Flux.1): https://github.com/black-forest-labs/flux

    The image stream and the text stream each own their modulation, QKV projection,
    Q/K RMSNorm, output projection and MLP. The two streams are concatenated along the
    sequence axis for one joint attention call and split again afterwards.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        dit_modulation_type: Optional[str] = "wanx",
        attn_backend: str = 'auto',
    ):
        """
        Args:
            hidden_size: Channel width of both streams.
            heads_num: Number of attention heads; the per-head width is
                ``hidden_size // heads_num``.
            mlp_width_ratio: MLP inner width as a multiple of ``hidden_size``.
            mlp_act_type: Accepted for interface compatibility; this block does not read
                it (both MLPs are built with ``"gelu-approximate"``).
            dtype: Forwarded to the modulation, linear and norm layers.
            device: Forwarded to the modulation, linear and norm layers.
            dit_modulation_type: Forwarded to ``load_modulation``.
            attn_backend: Backend name handed to ``attention``; ``'auto'`` is resolved
                through ``get_preferred_attention_backend()``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.attn_backend = get_preferred_attention_backend() if attn_backend == 'auto' else attn_backend
        self.dit_modulation_type = dit_modulation_type
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        # ------------------------------ image stream ------------------------------
        self.img_mod = load_modulation(
            modulate_type=self.dit_modulation_type,
            hidden_size=hidden_size,
            factor=6,
            **factory_kwargs,
        )
        self.img_norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        self.img_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=True, **factory_kwargs
        )
        self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True,
                                       eps=1e-6, **factory_kwargs)
        self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True,
                                       eps=1e-6, **factory_kwargs)
        self.img_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=True, **factory_kwargs
        )
        self.img_norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        # There is no dtype for FeedForward, because FSDP2 casts the dtype for all parameters.
        # You may need to give the dtype when no autocast and fsdp !!!
        self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim,
                                   activation_fn="gelu-approximate")

        # ------------------------------- text stream ------------------------------
        self.txt_mod = load_modulation(
            modulate_type=self.dit_modulation_type,
            hidden_size=hidden_size,
            factor=6,
            **factory_kwargs,
        )
        self.txt_norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        self.txt_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=True, **factory_kwargs
        )
        self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True,
                                       eps=1e-6, **factory_kwargs)
        self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True,
                                       eps=1e-6, **factory_kwargs)
        self.txt_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=True, **factory_kwargs
        )
        self.txt_norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim,
                                   activation_fn="gelu-approximate")

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        vis_freqs_cis: tuple = None,
        txt_freqs_cis: tuple = None,
        attn_kwargs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run one joint-attention block over the image and text token sequences.

        Args:
            img: Image tokens, indexed as ``(B, L_img, hidden_size)`` by the QKV rearrange.
            txt: Text tokens, indexed as ``(B, L_txt, hidden_size)`` by the QKV rearrange.
            vec: Conditioning input handed to both modulation modules, each of which
                yields six values (shift/scale/gate for the attention and MLP branches).
            vis_freqs_cis: Rotary frequencies for the image tokens, passed to
                ``apply_rotary_emb``; ``None`` skips RoPE on the image stream.
            txt_freqs_cis: Must be ``None``; RoPE on text is not supported.
            attn_kwargs: Extra keyword data forwarded untouched to ``attention``.
                ``None`` is treated as an empty dict.

        Returns:
            The updated ``(img, txt)`` token tensors.

        Raises:
            NotImplementedError: If ``txt_freqs_cis`` is given.
        """
        if txt_freqs_cis is not None:
            raise NotImplementedError("RoPE text is not supported for inference")
        if attn_kwargs is None:
            # Fresh dict per call; avoids sharing one mutable default across calls.
            attn_kwargs = {}

        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = self.img_mod(vec)
        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = self.txt_mod(vec)

        # Prepare image for attention.
        img_modulated = self.img_norm1(img)
        img_modulated = modulate(
            img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
        )
        img_qkv = self.img_attn_qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
        )
        # QK-Norm; cast back to the dtype/device of V afterwards.
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)

        # Apply RoPE if needed.
        if vis_freqs_cis is not None:
            img_qq, img_kk = apply_rotary_emb(
                img_q, img_k, vis_freqs_cis, head_first=False)
            assert (
                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
            ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
            img_q, img_k = img_qq, img_kk

        # Prepare txt for attention.
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = modulate(
            txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
        )
        txt_qkv = self.txt_attn_qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
        )
        # QK-Norm; cast back to the dtype/device of V afterwards.
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

        # Joint attention: image tokens first, text tokens after, along the sequence axis.
        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)
        attn = attention(
            q, k, v,
            backend=self.attn_backend,
            attn_kwargs=attn_kwargs,
        )
        # Merge dims 2 and 3 (presumably heads and head_dim) back into one channel axis.
        attn = attn.flatten(2, 3)

        # Split the joint result back into the two streams.
        img_len = img.shape[1]
        img_attn, txt_attn = attn[:, :img_len], attn[:, img_len:]

        # Calculate the img blocks: gated attention residual, then gated MLP residual.
        img = img + apply_gate(self.img_attn_proj(img_attn),
                               gate=img_mod1_gate)
        img = img + apply_gate(
            self.img_mlp(
                modulate(
                    self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
                )
            ),
            gate=img_mod2_gate,
        )

        # Calculate the txt blocks: gated attention residual, then gated MLP residual.
        txt = txt + apply_gate(self.txt_attn_proj(txt_attn),
                               gate=txt_mod1_gate)
        txt = txt + apply_gate(
            self.txt_mlp(
                modulate(
                    self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
                )
            ),
            gate=txt_mod2_gate,
        )
        return img, txt
class WanTimeTextImageEmbedding(nn.Module):
    """Embeds the diffusion timestep and projects the text encoder states."""

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
        pos_embed_seq_len: Optional[int] = None,
    ):
        """
        Args:
            dim: Output width of the timestep embedding and of the text projection.
            time_freq_dim: Number of channels of the sinusoidal timestep projection.
            time_proj_dim: Output width of the linear layer applied to the activated
                timestep embedding.
            text_embed_dim: Input width of the text encoder states.
            image_embed_dim: Not read by this module.
            pos_embed_seq_len: Not read by this module.
        """
        super().__init__()

        # Timestep path: sinusoidal features -> embedding MLP -> SiLU -> linear projection.
        self.timesteps_proj = Timesteps(
            num_channels=time_freq_dim,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        self.time_embedder = TimestepEmbedding(
            in_channels=time_freq_dim,
            time_embed_dim=dim,
        )
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)

        # Text path.
        self.text_embedder = PixArtAlphaTextProjection(
            text_embed_dim, dim, act_fn="gelu_tanh")

    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ):
        """
        Args:
            timestep: Timestep tensor fed to the sinusoidal projection.
            encoder_hidden_states: Text encoder states to be projected.

        Returns:
            A tuple ``(temb, timestep_proj, encoder_hidden_states)``: the timestep
            embedding (cast to the type of the text states), its activated linear
            projection, and the projected text states.
        """
        freq_features = self.timesteps_proj(timestep)

        # Match the embedder's parameter dtype, unless the embedder is held in int8.
        embedder_dtype = next(self.time_embedder.parameters()).dtype
        needs_cast = embedder_dtype != torch.int8 and freq_features.dtype != embedder_dtype
        if needs_cast:
            freq_features = freq_features.to(embedder_dtype)

        temb = self.time_embedder(freq_features).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))
        text_states = self.text_embedder(encoder_hidden_states)
        return temb, timestep_proj, text_states
class Transformer3DModel(ModelMixin, ConfigMixin):
    """
    Diffusion transformer over 3D (time, height, width) latents conditioned on text.

    The latent volume is patchified by a strided ``Conv3d``, the timestep and text states
    are embedded by ``WanTimeTextImageEmbedding``, the tokens pass through a stack of
    ``MMDoubleStreamBlock`` layers, and the image tokens are projected and unpatchified
    back into a latent volume.
    """

    _fsdp_shard_conditions: list = [
        lambda name, module: isinstance(module, (MMDoubleStreamBlock))]
    _supports_gradient_checkpointing = True

    # NOTE: the list defaults below are intentionally left as literals; `register_to_config`
    # records the signature defaults, so replacing them with `None` would alter the saved
    # config (assumption based on diffusers' ConfigMixin behaviour; verify before changing).
    @register_to_config
    def __init__(
        self,
        args: Any,
        patch_size: list = [1, 2, 2],
        in_channels: int = 4,  # Should be VAE.config.latent_channels.
        out_channels: int = None,
        hidden_size: int = 3072,
        heads_num: int = 24,
        text_states_dim: int = 4096,
        mlp_width_ratio: float = 4.0,
        mm_double_blocks_depth: int = 20,
        rope_dim_list: List[int] = [16, 56, 56],
        rope_type: str = 'rope',
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        dit_modulation_type: str = "wanx",
        attn_backend: str = 'auto',
        theta: int = 256,
    ):
        """
        Args:
            args: Opaque object stored on the model as ``self.args``.
            patch_size: ``(pt, ph, pw)`` kernel and stride of the patchifying ``Conv3d``.
            in_channels: Number of channels of the input latent.
            out_channels: Number of channels of the output latent; falls back to
                ``in_channels`` when falsy.
            hidden_size: Token width; must be divisible by ``heads_num``.
            heads_num: Number of attention heads.
            text_states_dim: Width of the incoming text encoder states.
            mlp_width_ratio: MLP inner width as a multiple of ``hidden_size``.
            mm_double_blocks_depth: Number of ``MMDoubleStreamBlock`` layers.
            rope_dim_list: Per-axis rotary dimensions; must sum to the head width.
            rope_type: ``'mrope'`` additionally requests text rotary frequencies.
            dtype: Forwarded to the blocks and the output projection.
            device: Forwarded to the blocks and the output projection.
            dit_modulation_type: Forwarded to every block.
            attn_backend: Attention backend; ``'auto'`` is resolved through
                ``get_preferred_attention_backend()``.
            theta: Forwarded as ``theta`` to ``get_nd_rotary_pos_embed``.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``heads_num``.
        """
        # Initialise the nn.Module machinery before any attribute is assigned.
        super().__init__()

        if hidden_size % heads_num != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
            )

        self.args = args
        self.out_channels = out_channels or in_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        self.rope_dim_list = rope_dim_list
        self.dit_modulation_type = dit_modulation_type
        self.mm_double_blocks_depth = mm_double_blocks_depth
        self.attn_backend = get_preferred_attention_backend() if attn_backend == 'auto' else attn_backend
        self.rope_type = rope_type
        self.theta = theta
        factory_kwargs = {"device": device, "dtype": dtype}

        # image projection
        self.img_in = nn.Conv3d(
            in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

        # condition embedding
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=hidden_size,
            time_freq_dim=256,
            time_proj_dim=hidden_size * 6,
            text_embed_dim=text_states_dim,
        )

        # double blocks
        # The already-resolved backend is passed down so that the blocks and the
        # `attn_kwargs` prepared in forward() always agree on the backend.
        self.double_blocks = nn.ModuleList(
            [
                MMDoubleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    dit_modulation_type=self.dit_modulation_type,
                    attn_backend=self.attn_backend,
                    **factory_kwargs,
                )
                for _ in range(mm_double_blocks_depth)
            ]
        )

        # Output norm & projection
        self.norm_out = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6
        )
        # Use the resolved `self.out_channels`: the raw argument defaults to None.
        self.proj_out = nn.Linear(
            hidden_size, self.out_channels * math.prod(patch_size),
            **factory_kwargs)

    def get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None):
        """
        Build rotary position embeddings for the visual (and optionally text) tokens.

        Args:
            vis_rope_size: Token-grid size per axis; sequences shorter than three entries
                are left-padded with 1 (the time axis).
            txt_rope_size: Forwarded to ``get_nd_rotary_pos_embed``.

        Returns:
            The ``(vis_freqs, txt_freqs)`` pair returned by ``get_nd_rotary_pos_embed``.
        """
        target_ndim = 3

        if len(vis_rope_size) != target_ndim:
            # list(...) keeps the padding working for tuple inputs as well.
            vis_rope_size = [1] * (target_ndim - len(vis_rope_size)
                                   ) + list(vis_rope_size)  # time axis

        head_dim = self.hidden_size // self.heads_num
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim //
                             target_ndim for _ in range(target_ndim)]
        assert (
            sum(rope_dim_list) == head_dim
        ), "sum(rope_dim_list) should equal to head_dim of attention layer"

        vis_freqs, txt_freqs = get_nd_rotary_pos_embed(
            rope_dim_list,
            vis_rope_size,
            txt_rope_size=txt_rope_size,
            theta=self.theta,
            use_real=True,
            theta_rescale_factor=1,
        )
        return vis_freqs, txt_freqs

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,  # Should be in range(0, 1000).
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        return_dict: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the transformer on a latent volume.

        Args:
            hidden_states: Latents shaped ``(b, c, t, h, w)``, or ``(b, n, c, t, h, w)``
                for multi-item input (the items are folded into the temporal axis).
            timestep: Diffusion timestep tensor.
            encoder_hidden_states: Text encoder states; required despite the ``None``
                default.
            encoder_hidden_states_mask: Boolean mask over the text tokens; defaults to
                all ones. Only read when the attention backend is ``'flash_attn'``.
            return_dict: Accepted for interface compatibility; not read.

        Returns:
            A tuple ``(img, txt)``: the unpatchified output latents (reshaped back to
            ``(b, n, c, t, h, w)`` for multi-item input) and the text tokens after the
            last block.
        """
        # For Multi-item Input: hidden_states: (b, n, c, t, h, w)
        # Permute the items into the temporal dimension
        is_multi_item = (len(hidden_states.shape) == 6)
        num_items = 0
        if is_multi_item:
            num_items = hidden_states.shape[1]
            if num_items > 1:
                assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1"
                # Move the last item to the first position
                hidden_states = torch.cat(
                    [
                        hidden_states[:, -1:],
                        hidden_states[:, :-1]
                    ],
                    dim=1
                )
            hidden_states = rearrange(
                hidden_states, 'b n c t h w -> b c (n t) h w')

        _, _, ot, oh, ow = hidden_states.shape
        tt, th, tw = (
            ot // self.patch_size[0],
            oh // self.patch_size[1],
            ow // self.patch_size[2],
        )

        # Text Mask
        if encoder_hidden_states_mask is None:
            encoder_hidden_states_mask = torch.ones(
                (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]),
                dtype=torch.bool,
                device=encoder_hidden_states.device,
            )

        # Prepare img, txt, vec.
        img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
        _, vec, txt = self.condition_embedder(
            timestep, encoder_hidden_states)
        if vec.shape[-1] > self.hidden_size:
            # Split the projected timestep embedding into six chunks along a new axis.
            vec = vec.unflatten(1, (6, -1))

        txt_seq_len = txt.shape[1]
        img_seq_len = img.shape[1]

        # rope
        vis_freqs_cis, txt_freqs_cis = self.get_rotary_pos_embed(
            vis_rope_size=(tt, th, tw),
            txt_rope_size=txt_seq_len if self.rope_type == 'mrope' else None,
        )

        # Compute attn_kwargs
        attn_kwargs = {'thw': [tt, th, tw], 'txt_len': txt_seq_len}
        if self.attn_backend == 'flash_attn':
            cu_seqlens_q = get_cu_seqlens(
                encoder_hidden_states_mask, img_seq_len)
            cu_seqlens_kv = cu_seqlens_q
            max_seqlen_q = img_seq_len + txt_seq_len
            max_seqlen_kv = max_seqlen_q
            attn_kwargs.update({
                'cu_seqlens_q': cu_seqlens_q,
                'cu_seqlens_kv': cu_seqlens_kv,
                'max_seqlen_q': max_seqlen_q,
                'max_seqlen_kv': max_seqlen_kv,
            })

        # --------------------- Pass through DiT blocks ------------------------
        for block in self.double_blocks:
            img, txt = block(
                img,
                txt,
                vec,
                vis_freqs_cis,
                txt_freqs_cis,
                attn_kwargs,
            )

        # ---------------------------- Final layer ------------------------------
        img = self.proj_out(self.norm_out(img))
        img = self.unpatchify(img, tt, th, tw)

        # Reshape back to multiple items
        if is_multi_item:
            img = rearrange(
                img, 'b c (n t) h w -> b n c t h w', n=num_items)
            if num_items > 1:
                # Move the first item back to the last position
                img = torch.cat(
                    [
                        img[:, 1:],
                        img[:, :1]
                    ],
                    dim=1
                )
        return img, txt

    def unpatchify(self, x, t, h, w):
        """
        Fold a token sequence back into a latent volume.

        Args:
            x: Tokens shaped ``(N, t * h * w, pt * ph * pw * C)`` where
                ``(pt, ph, pw)`` is ``self.patch_size`` and ``C`` is ``self.out_channels``.
            t, h, w: Token-grid size along time, height and width.

        Returns:
            Tensor shaped ``(N, C, t * pt, h * ph, w * pw)``.
        """
        c = self.out_channels
        pt, ph, pw = self.patch_size
        assert t * h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        # Interleave each grid axis with its patch axis: (t, pt), (h, ph), (w, pw).
        x = torch.einsum("nthwopqc->nctohpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs