43 / Meissonic /src /transformer_video.py
BryanW's picture
Upload folder using huggingface_hub
3d1c0e1 verified
import math
from typing import Optional
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
# Global debug flag - set to False to disable debug prints
DEBUG_TRANSFORMER = False
# from .attention import flash_attention
import torch
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
import warnings
__all__ = [
'flash_attention',
'attention',
]
def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    Variable-length attention through the FlashAttention 3 or 2 kernels.

    Args:
        q: [B, Lq, Nq, C1].
        k: [B, Lk, Nk, C1].
        v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
        q_lens: [B] valid query lengths, or None to use Lq for every sample.
        k_lens: [B] valid key/value lengths, or None to use Lk for every sample.
        dropout_p: float. Dropout probability (only forwarded to the FA2 kernel).
        softmax_scale: float. The scaling of QK^T before applying softmax.
        q_scale: optional multiplier applied to the packed queries.
        causal: bool. Whether to apply causal attention mask.
        window_size: (left, right). Only forwarded to the FA2 kernel.
        deterministic: bool. Forwarded to the kernel.
        dtype: torch.dtype. Half dtype used when q/k/v are not float16/bfloat16.
        version: None or 3 selects FA3 when it is installed; otherwise FA2 is used.

    Returns:
        Tensor of shape [B, Lq, ...] cast back to the dtype ``q`` had on entry.
    """
    low_precision = (torch.float16, torch.bfloat16)
    assert dtype in low_precision
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    batch = q.size(0)
    max_q = q.size(1)
    max_k = k.size(1)
    result_dtype = q.dtype

    def as_half(t):
        # Leave tensors that are already fp16/bf16 untouched.
        if t.dtype in low_precision:
            return t
        return t.to(dtype)

    def pack(t, lens):
        # Keep the first `n` tokens of each sample and concatenate them.
        return torch.cat([row[:n] for row, n in zip(t, lens)])

    def uniform_lens(length, device):
        return torch.tensor(
            [length] * batch, dtype=torch.int32).to(
                device=device, non_blocking=True)

    def cumulative(lens, device):
        # Prefix sums with a leading zero, as the varlen kernels expect.
        padded = torch.cat([lens.new_zeros([1]), lens])
        return padded.cumsum(0, dtype=torch.int32).to(
            device, non_blocking=True)

    # Pack the queries.
    if q_lens is None:
        q = as_half(q.flatten(0, 1))
        q_lens = uniform_lens(max_q, q.device)
    else:
        q = as_half(pack(q, q_lens))

    # Pack the keys and values.
    if k_lens is None:
        k = as_half(k.flatten(0, 1))
        v = as_half(v.flatten(0, 1))
        k_lens = uniform_lens(max_k, k.device)
    else:
        k = as_half(pack(k, k_lens))
        v = as_half(pack(v, k_lens))

    # q and k follow the dtype of v.
    q = q.to(v.dtype)
    k = k.to(v.dtype)
    if q_scale is not None:
        q = q * q_scale

    if version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            'Flash attention 3 is not available, use flash attention 2 instead.'
        )

    wants_fa3 = version is None or version == 3
    if wants_fa3 and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        packed = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cumulative(q_lens, q.device),
            cu_seqlens_k=cumulative(k_lens, q.device),
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=max_q,
            max_seqlen_k=max_k,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic)[0]
    else:
        assert FLASH_ATTN_2_AVAILABLE
        packed = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cumulative(q_lens, q.device),
            cu_seqlens_k=cumulative(k_lens, q.device),
            max_seqlen_q=max_q,
            max_seqlen_k=max_k,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic)

    # Restore the batch dimension and the caller's dtype.
    return packed.unflatten(0, (batch, max_q)).type(result_dtype)
def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    """
    Attention entry point that prefers FlashAttention and falls back to
    ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        q: [B, Lq, N, C] queries.
        k: [B, Lk, N, C] keys.
        v: [B, Lk, N, C] values.
        q_lens, k_lens: optional [B] valid lengths. Only honoured by the
            FlashAttention path; the fallback warns and ignores them.
        dropout_p: dropout probability.
        softmax_scale: scaling of QK^T before softmax (None = 1/sqrt(C)).
        q_scale: optional multiplier applied to the queries.
        causal: whether to apply a causal mask.
        window_size, deterministic: only used by the FlashAttention path.
        dtype: compute dtype; the fallback casts q/k/v to it and returns it.
        fa_version: forwarded to ``flash_attention`` as ``version``.

    Returns:
        Tensor of shape [B, Lq, N, C].
    """
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )

    if q_lens is not None or k_lens is not None:
        warnings.warn(
            'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
        )
    attn_mask = None

    # Honour q_scale / softmax_scale in the fallback as the flash path does.
    # SDPA always scales by 1/sqrt(C), so a custom softmax_scale is folded
    # into q as softmax_scale * sqrt(C). This avoids the `scale=` keyword,
    # which older torch releases do not accept.
    if q_scale is not None:
        q = q * q_scale
    if softmax_scale is not None:
        q = q * (softmax_scale * q.size(-1) ** 0.5)

    # SDPA expects [B, N, L, C].
    q = q.transpose(1, 2).to(dtype)
    k = k.transpose(1, 2).to(dtype)
    v = v.transpose(1, 2).to(dtype)

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)

    # Back to [B, L, N, C].
    out = out.transpose(1, 2).contiguous()
    return out
# NOTE: this rebinds the module-level `__all__` declared earlier in this file
# (['flash_attention', 'attention']); from here on only 'WanModel' is exported
# by a star-import of this module.
__all__ = ['WanModel']
def sinusoidal_embedding_1d(dim, position):
    """
    Build a 1-D sinusoidal embedding for every entry of ``position``.

    Args:
        dim: embedding width; must be even.
        position: 1-D tensor of positions (any dtype, any device).

    Returns:
        float64 tensor of shape [len(position), dim] on the device of
        ``position``: the cosine half followed by the sine half.
    """
    assert dim % 2 == 0
    num_freqs = dim // 2

    # All arithmetic is done in float64 on the device of `position`.
    pos64 = position.to(torch.float64)
    exponents = torch.arange(
        num_freqs, dtype=torch.float64, device=position.device).div(num_freqs)
    inv_freq = torch.pow(10000, -exponents)

    angles = torch.outer(pos64, inv_freq)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
@torch.amp.autocast('cuda', enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    """
    Precompute complex rotary factors for positions ``0 .. max_seq_len - 1``.

    Args:
        max_seq_len: number of positions to tabulate.
        dim: rotary dimension; must be even (dim // 2 frequencies are produced).
        theta: base of the geometric frequency progression.

    Returns:
        Complex tensor of shape [max_seq_len, dim // 2] with unit modulus.
    """
    assert dim % 2 == 0
    exponents = torch.arange(0, dim, 2).to(torch.float64).div(dim)
    inv_freq = 1.0 / torch.pow(theta, exponents)
    positions = torch.arange(max_seq_len)
    angles = torch.outer(positions, inv_freq)
    # polar(1, angle) == exp(i * angle)
    return torch.polar(torch.ones_like(angles), angles)
@torch.amp.autocast('cuda', enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """
    Apply 3-D (frame / height / width) rotary position embedding.

    Args:
        x: tensor of shape [B, L, N, D]; D is split into D // 2 complex pairs.
        grid_sizes: tensor of shape [B, 3] holding (F, H, W) per sample. The
            first F * H * W tokens of a sample are rotated, the rest are
            passed through unchanged.
        freqs: complex tensor of shape [P, D // 2] of rotary factors, split
            along dim 1 into temporal / height / width parts.

    Returns:
        Tensor shaped like ``x`` with the dtype of ``x``.
    """
    num_heads, num_pairs = x.size(2), x.size(3) // 2
    in_dtype = x.dtype

    # The temporal axis takes whatever the two spatial axes leave over.
    per_axis = num_pairs // 3
    freqs_t, freqs_h, freqs_w = freqs.split(
        [num_pairs - 2 * per_axis, per_axis, per_axis], dim=1)

    rotated = []
    for idx, grid in enumerate(grid_sizes.tolist()):
        f, h, w = grid
        seq_len = f * h * w

        # View the valid tokens as complex numbers in float64.
        pairs = x[idx, :seq_len].to(torch.float64).reshape(
            seq_len, num_heads, -1, 2)
        tokens = torch.view_as_complex(pairs)

        # Broadcast each axis' factors over the full (F, H, W) grid.
        phase = torch.cat(
            [
                freqs_t[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                freqs_h[:h].view(1, h, 1, -1).expand(f, h, w, -1),
                freqs_w[:w].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1).reshape(seq_len, 1, -1)

        # Rotate, then return to real layout and the input dtype.
        tokens = torch.view_as_real(tokens * phase).flatten(2)
        tokens = tokens.to(dtype=in_dtype)

        # Tokens beyond the grid (padding) are appended untouched.
        tail = x[idx, seq_len:]
        if tail.numel() > 0:
            tokens = torch.cat([tokens, tail])

        rotated.append(tokens)

    return torch.stack(rotated).to(dtype=in_dtype)
class WanRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]

        Returns:
            Tensor with the shape and dtype of ``x``.
        """
        # Normalise in float32, then bring both factors to the dtype of x.
        normed = self._norm(x.float()).type_as(x)
        scale = self.weight.type_as(x)
        return normed * scale

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), computed along the last dimension.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)
class WanLayerNorm(nn.LayerNorm):
    """LayerNorm that always computes in float32 and returns the input dtype."""

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]

        Returns:
            Tensor with the shape and dtype of ``x``.
        """
        in_dtype = x.dtype
        x32 = x.float()

        if not self.elementwise_affine:
            # No learnable parameters: the stock implementation is enough.
            return super().forward(x32).to(dtype=in_dtype)

        # Upcast the affine parameters so they match the float32 input.
        weight32 = None if self.weight is None else self.weight.float()
        bias32 = None if self.bias is None else self.bias.float()
        out = torch.nn.functional.layer_norm(
            x32, self.normalized_shape, weight32, bias32, self.eps)
        return out.to(dtype=in_dtype)
class WanSelfAttention(nn.Module):
    """Multi-head self-attention with optional q/k RMS norm and rotary embedding."""

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # projections
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        # q/k normalisation is applied on the full `dim`, before the head split
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]; projected and split into heads here
            seq_lens(Tensor): Shape [B], passed to flash_attention as k_lens
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]

        Returns:
            Tensor of shape [B, L, C].
        """
        batch, seq = x.shape[:2]
        heads, head_dim = self.num_heads, self.head_dim
        # Remember the dtype so the attention output can be cast back to it.
        in_dtype = x.dtype

        # Project, normalise (q/k only) and split into heads.
        query = self.norm_q(self.q(x)).view(batch, seq, heads, head_dim)
        key = self.norm_k(self.k(x)).view(batch, seq, heads, head_dim)
        value = self.v(x).view(batch, seq, heads, head_dim)

        # Rotary embedding on q and k, then flash attention.
        attn = flash_attention(
            q=rope_apply(query, grid_sizes, freqs),
            k=rope_apply(key, grid_sizes, freqs),
            v=value,
            k_lens=seq_lens,
            window_size=self.window_size)
        attn = attn.to(dtype=in_dtype)

        # Merge heads and apply the output projection.
        return self.o(attn.flatten(2))
class WanCrossAttention(WanSelfAttention):
    """Cross-attention: queries from ``x``, keys and values from ``context``."""

    def forward(self, x, context, context_lens):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]

        Returns:
            Tensor of shape [B, L1, C].
        """
        batch = x.size(0)
        heads, head_dim = self.num_heads, self.head_dim
        # Remember the dtype so the attention output can be cast back to it.
        in_dtype = x.dtype

        # Queries come from x; keys and values come from the context.
        query = self.norm_q(self.q(x)).view(batch, -1, heads, head_dim)
        key = self.norm_k(self.k(context)).view(batch, -1, heads, head_dim)
        value = self.v(context).view(batch, -1, heads, head_dim)

        attn = flash_attention(query, key, value, k_lens=context_lens)
        attn = attn.to(dtype=in_dtype)

        # Merge heads and apply the output projection.
        return self.o(attn.flatten(2))
class WanAttentionBlock(nn.Module):
    """Transformer block: modulated self-attention, cross-attention, modulated FFN."""

    def __init__(self,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # self-attention branch
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                          eps)
        # cross-attention branch (affine norm only when requested)
        if cross_attn_norm:
            self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True)
        else:
            self.norm3 = nn.Identity()
        self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
                                            eps)
        # feed-forward branch
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))

        # learned offset added to the six modulation vectors
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
    ):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, L1, 6, C]
            seq_lens(Tensor): Shape [B], length of each sequence in batch
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
            context(Tensor): passed to the cross-attention as keys/values
            context_lens(Tensor or None): passed to the cross-attention

        Returns:
            Tensor of shape [B, L, C].
        """
        # The six modulation vectors are formed in float32.
        e32 = e.to(dtype=torch.float32)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            mods = (self.modulation.unsqueeze(0) + e32).chunk(6, dim=2)
        assert mods[0].dtype == torch.float32

        def mod(index, dtype):
            # Drop the singleton chunk axis and match the activation dtype.
            return mods[index].squeeze(2).to(dtype=dtype)

        # --- self-attention: shift = mods[0], scale = mods[1], gate = mods[2]
        sa_dtype = x.dtype
        shift_sa = mod(0, sa_dtype)
        scale_sa = mod(1, sa_dtype)
        gate_sa = mod(2, sa_dtype)
        sa_in = self.norm1(x) * (1 + scale_sa) + shift_sa
        sa_out = self.self_attn(sa_in, seq_lens, grid_sizes, freqs)
        x = x + (sa_out * gate_sa).to(dtype=sa_dtype)

        # --- cross-attention (not modulated)
        x = x + self.cross_attn(self.norm3(x), context, context_lens)

        # --- feed-forward: shift = mods[3], scale = mods[4], gate = mods[5]
        ffn_dtype = x.dtype
        shift_ffn = mod(3, ffn_dtype)
        scale_ffn = mod(4, ffn_dtype)
        gate_ffn = mod(5, ffn_dtype)
        ffn_in = self.norm2(x) * (1 + scale_ffn) + shift_ffn
        ffn_out = self.ffn(ffn_in)
        x = x + (ffn_out * gate_ffn).to(dtype=ffn_dtype)
        return x
class Head(nn.Module):
    """Output head: modulated LayerNorm followed by a linear patch projection."""

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # One output vector holds a whole patch: prod(patch_size) * out_dim values.
        projected_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, projected_dim)

        # learned offset added to the two modulation vectors
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, L1, C]

        Returns:
            Tensor of shape [B, L1, prod(patch_size) * out_dim].
        """
        # Form the shift/scale pair in float32.
        e32 = e.to(dtype=torch.float32)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            mods = (self.modulation.unsqueeze(0) + e32.unsqueeze(2)).chunk(
                2, dim=2)

        # Bring both vectors to the activation dtype before applying them.
        act_dtype = x.dtype
        shift = mods[0].squeeze(2).to(dtype=act_dtype)
        scale = mods[1].squeeze(2).to(dtype=act_dtype)

        modulated = self.norm(x) * (1 + scale) + shift
        return self.head(modulated)
class WanModel(ModelMixin, ConfigMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    ignore_for_config = [
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    ]
    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant; one of 't2v', 'i2v', 'ti2v', 's2v'
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """
        super().__init__()
        assert model_type in ['t2v', 'i2v', 'ti2v', 's2v']
        self.model_type = model_type
        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        self.blocks = nn.ModuleList([
            WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
                              cross_attn_norm, eps) for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        # Rotary tables for the temporal, height and width axes, concatenated
        # along the frequency dimension (rope_apply splits them again).
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ],
                               dim=1)

        # initialize weights
        self.init_weights()

    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        y=None,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B] (one per sample,
                broadcast over the sequence) or [B, seq_len]
            context (List[Tensor]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding; every sample
                is zero-padded to this many tokens
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x

        Returns:
            List[Tensor]:
                One float32 tensor per sample, as produced by `unpatchify`
        """
        if self.model_type == 'i2v':
            assert y is not None
        # params
        device = self.patch_embedding.weight.device
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        # Conditioning video is concatenated along the channel axis.
        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        # Ensure input dtype matches patch_embedding weight dtype
        patch_weight_dtype = self.patch_embedding.weight.dtype
        x = [self.patch_embedding(u.unsqueeze(0).to(dtype=patch_weight_dtype)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        # Zero-pad every sample to seq_len tokens and batch them.
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        if t.dim() == 1:
            # [B] -> [B, 1] -> [B, seq_len]. Expanding the 1-D tensor directly
            # aligns B with the *last* target dimension and fails for B > 1.
            t = t.unsqueeze(1).expand(-1, seq_len)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            bt = t.size(0)
            t = t.flatten()
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim,
                                        t).unflatten(0, (bt, seq_len)).float())
            e0 = self.time_projection(e).unflatten(2, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32
        # Keep e and e0 as float32 for modulation computation
        # They will be converted to x.dtype inside WanAttentionBlock.forward and Head.forward when needed

        # context
        context_lens = None
        # Ensure context input dtype matches text_embedding weight dtype
        text_weight_dtype = self.text_embedding[0].weight.dtype
        # Each text embedding is zero-padded to text_len rows before stacking.
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]).to(dtype=text_weight_dtype))

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens)

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors, each with shape
                [C_out, F_patches * p_t, H_patches * p_h, W_patches * p_w]
        """
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            # Drop padding tokens, then expose grid and in-patch axes.
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            # Interleave each grid axis with its in-patch axis.
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """
        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
class WanDiscreteVideoTransformer(ModelMixin, ConfigMixin):
    r"""
    Wrapper around :class:`WanModel` that makes it usable as a **discrete video diffusion backbone**.

    The goals of this wrapper are:
    - keep the inner :class:`WanModel` architecture and parameter names intact so that Wan-1.3B
      weights can later be loaded directly into ``self.backbone``;
    - expose a simpler interface that takes **discrete codebook indices** (from a 2D VQ-VAE on
      pseudo-video) and returns **logits over the vocabulary** for each spatio‑temporal position.

    Notes
    -----
    - This class does **not** try to be drop‑in compatible with Meissonic's 2D ``Transformer2DModel``.
      It is a parallel, video‑oriented path that still follows the same *discrete diffusion* principle:
      predict per‑token logits given masked tokens + text.
    - Pseudo‑video is represented as a 4D integer tensor ``[B, F, H, W]`` of codebook indices.
      How to get these tokens from the current 2D VQ-VAE (e.g. per‑frame encoding & stacking)
      is left to the higher‑level training / pipeline code.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        # discrete codebook settings
        codebook_size: int,
        vocab_size: int,
        # video layout
        num_frames: int,
        height: int,
        width: int,
        # Wan backbone hyper‑parameters (mirrors WanModel.__init__)
        model_type: str = 't2v',
        patch_size: tuple = (1, 2, 2),
        text_len: int = 512,
        in_dim: int = 16,
        dim: int = 2048,
        ffn_dim: int = 8192,
        freq_dim: int = 256,
        text_dim: int = 4096,
        out_dim: int = 16,
        num_heads: int = 16,
        num_layers: int = 32,
        window_size: tuple = (-1, -1),
        qk_norm: bool = True,
        cross_attn_norm: bool = True,
        eps: float = 1e-6,
    ):
        r"""
        Build the backbone, the token embedding and the logits head.

        Args:
            codebook_size: number of real codes; stored on the instance only.
            vocab_size: size of the embedding table and of the logits channel
                dimension (per the comments below, ``codebook_size + 1``).
            num_frames, height, width: stored on the instance; ``forward``
                derives the sequence length from the actual token shape.
            Remaining arguments are forwarded unchanged to :class:`WanModel`.
        """
        super().__init__()
        # save a minimal set of attributes useful for downstream tooling
        self.codebook_size = codebook_size
        self.vocab_size = vocab_size
        self.num_frames = num_frames
        self.height = height
        self.width = width

        # 1) backbone: keep WanModel intact for future weight loading
        self.backbone = WanModel(
            model_type=model_type,
            patch_size=patch_size,
            text_len=text_len,
            in_dim=in_dim,
            dim=dim,
            ffn_dim=ffn_dim,
            freq_dim=freq_dim,
            text_dim=text_dim,
            out_dim=out_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            window_size=window_size,
            qk_norm=qk_norm,
            cross_attn_norm=cross_attn_norm,
            eps=eps,
        )

        # 2) discrete token embedding -> continuous video volume
        #
        # Input: tokens [B, F, H, W] with values in [0, vocab_size) where:
        # - [0, codebook_size-1] = actual Cosmos codes (direct mapping, no shift)
        # - codebook_size = mask_token_id (reserved for masking)
        # Output: list of length B with tensors [in_dim, F, H, W]
        #
        # We keep this outside the backbone so that loading official Wan 1.3B weights
        # into self.backbone will still work without clashes.
        # Note: vocab_size = codebook_size + 1 to accommodate mask_token_id = codebook_size
        self.token_embedding = nn.Embedding(vocab_size, in_dim)

        # 3) projection from continuous video output -> logits over the vocabulary
        #
        # Backbone output: list of B tensors [out_dim, F, H', W']
        # We map it with a 3D 1x1x1 conv to [vocab_size, F, H', W'].
        # Note: vocab_size = codebook_size + 1, where codebook_size is reserved for mask_token_id
        self.logits_head = nn.Conv3d(out_dim, vocab_size, kernel_size=1)

        # Gradient checkpointing support
        self.gradient_checkpointing = False

    def _tokens_to_video(self, tokens: torch.LongTensor) -> list:
        r"""
        Convert discrete tokens ``[B, F, H, W]`` into a list of length ``B`` where each element
        is a dense video tensor ``[in_dim, F, H, W]`` suitable for :class:`WanModel`.

        Note:
            No check is made against the ``num_frames`` / ``height`` / ``width`` stored on
            the instance; any 4D token tensor is accepted here.
        """
        assert tokens.dim() == 4, f"expected [B, F, H, W] tokens, got {tokens.shape}"
        # Dynamic dimensions - no strict dimension checks, WanModel handles variable sizes
        # [B, F, H, W, in_dim]
        x = self.token_embedding(tokens)
        # Cast to the embedding weight dtype (the lookup result is expected to have it already)
        token_embedding_dtype = self.token_embedding.weight.dtype
        x = x.to(dtype=token_embedding_dtype)
        # [B, in_dim, F, H, W]
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        # WanModel expects a list of [C_in, F, H, W]
        return [x_i for x_i in x]

    def _text_to_list(self, encoder_hidden_states: torch.Tensor) -> list:
        r"""
        Convert batched text embeddings ``[B, L, C]`` into the list-of-tensors format
        expected by :class:`WanModel` (one ``[L, C]`` tensor per sample).
        """
        assert encoder_hidden_states.dim() == 3, (
            f"expected encoder_hidden_states [B, L, C], got {encoder_hidden_states.shape}")
        return [e for e in encoder_hidden_states]

    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        """Set gradient checkpointing for the module.

        ``gradient_checkpointing_func`` is accepted but not used here.
        """
        self.gradient_checkpointing = enable

    def forward(
        self,
        tokens: torch.LongTensor,
        timesteps: torch.LongTensor,
        encoder_hidden_states: torch.FloatTensor,
        y: Optional[list] = None,
    ) -> torch.FloatTensor:
        r"""
        Forward pass of the **discrete video transformer**.

        Args:
            tokens (`torch.LongTensor` of shape `[B, F, H, W]`):
                Discrete codebook indices (e.g. from a 2D VQ-VAE applied frame‑wise).
            timesteps (`torch.LongTensor` of shape `[B]` or `[B, seq_len]`):
                Diffusion timestep(s). ``seq_len`` is ``F * (H // p_h) * (W // p_w)`` as
                computed below; a 1D tensor is broadcast along the sequence.
            encoder_hidden_states (`torch.FloatTensor` of shape `[B, L, C_text]`):
                Text embeddings (e.g. from CLIP). Each sample corresponds to one video.
            y (`Optional[list]`):
                Optional conditional video list passed to the underlying :class:`WanModel`
                for i2v / ti2v / s2v variants. For now this is surfaced as a raw passthrough
                and can be left as ``None`` for pure text‑to‑video.

        Returns:
            `torch.FloatTensor`:
                Logits of shape `[B, vocab_size, F_out, H_out, W_out]`: the channel dimension
                comes from ``logits_head`` (``vocab_size`` outputs, not ``codebook_size``).
                The spatial/temporal sizes are those returned by ``WanModel.unpatchify``,
                i.e. the patch grid multiplied back by the patch size, so they equal the
                input ``F, H, W`` whenever those are divisible by the patch size.

        Raises:
            ValueError: if ``timesteps`` is neither 1D nor 2D.
        """
        device = tokens.device
        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] Input: tokens.shape={tokens.shape}, encoder_hidden_states.shape={encoder_hidden_states.shape}, timesteps.shape={timesteps.shape}")

        x_list = self._tokens_to_video(tokens)
        context_list = self._text_to_list(encoder_hidden_states)

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] After conversion: len(x_list)={len(x_list)}, len(context_list)={len(context_list)}")
            if len(x_list) > 0:
                print(f"[DEBUG-transformer] x_list[0].shape={x_list[0].shape}")
            if len(context_list) > 0:
                print(f"[DEBUG-transformer] context_list[0].shape={context_list[0].shape}")

        # Calculate seq_len from actual input dimensions (supports dynamic sizes)
        # tokens: [B, F, H, W] -> after patchification: seq_len = F * (H/p_h) * (W/p_w)
        # NOTE(review): the temporal patch size patch_size[0] is not applied to F here;
        # this matches the backbone's token count only when patch_size[0] == 1 - confirm.
        _, f_in, h_in, w_in = tokens.shape
        h_patch = h_in // self.backbone.patch_size[1]
        w_patch = w_in // self.backbone.patch_size[2]
        seq_len = f_in * h_patch * w_patch

        # Prepare timesteps in the exact shape WanModel.forward expects.
        # Its current implementation assumes `t` is either [B, seq_len] or will be
        # expanded from 1D; the 1D branch is slightly buggy for non-singleton dims,
        # so we always give it a [B, seq_len] tensor here.
        if timesteps.dim() == 1:
            # [B] -> [B, 1] -> [B, seq_len] (broadcast along sequence)
            t_model = timesteps.to(device).unsqueeze(1).expand(-1, seq_len)
        elif timesteps.dim() == 2:
            assert timesteps.size(1) == seq_len, (
                f"Expected timesteps second dim == seq_len ({seq_len}), "
                f"but got {timesteps.size(1)}"
            )
            t_model = timesteps.to(device)
        else:
            raise ValueError(
                f"Unsupported timesteps shape {timesteps.shape}; "
                "expected [B] or [B, seq_len]"
            )

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] t_model.shape={t_model.shape}")

        # WanModel.forward expects:
        # x: List[Tensor [C_in, F, H, W]]
        # t: Tensor [B] or [B, seq_len]
        # context: List[Tensor [L, C_text]]
        # seq_len: int
        # y: Optional[List[Tensor]]
        if self.training and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # Unpack inputs: x_list, t, context_list, seq_len, y
                    x_in, t_in, context_in, seq_len_in, y_in = inputs
                    return module(x=x_in, t=t_in, context=context_in, seq_len=seq_len_in, y=y_in)
                return custom_forward

            # Use gradient checkpointing for the backbone (whole backbone is one checkpointed segment)
            ckpt_kwargs = {"use_reentrant": False}
            out_list = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.backbone),
                x_list,
                t_model,
                context_list,
                seq_len,
                y,
                **ckpt_kwargs,
            )
        else:
            out_list = self.backbone(
                x=x_list,
                t=t_model,
                context=context_list,
                seq_len=seq_len,
                y=y,
            )

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] After backbone: len(out_list)={len(out_list)}")
            if len(out_list) > 0:
                print(f"[DEBUG-transformer] out_list[0].shape={out_list[0].shape}")

        # out_list: length B, each [C_out, F, H_out, W_out]
        vids = torch.stack(out_list, dim=0)  # [B, C_out, F, H_out, W_out]

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] After stack: vids.shape={vids.shape}")

        # Ensure vids dtype matches logits_head weight dtype
        # (the backbone returns float32 tensors, the head may be in a lower precision)
        vids = vids.to(dtype=self.logits_head.weight.dtype)
        logits = self.logits_head(vids)  # [B, vocab_size, F, H_out, W_out] where vocab_size = codebook_size + 1

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] Final logits.shape={logits.shape}")

        return logits
# def _available_device():
# return "cuda" if torch.cuda.is_available() else "cpu"
# def test_wan_discrete_video_transformer_forward_and_shapes():
# """
# Basic smoke test:
# - build a tiny WanDiscreteVideoTransformer
# - run a forward pass with random pseudo-video tokens + random text
# - check output shapes, parameter count and (if CUDA present) memory usage
# """
# device = _available_device()
# # small config to keep the test lightweight
# codebook_size = 128
# vocab_size = codebook_size + 1 # reserve one for mask if needed later
# num_frames = 2
# height = 16
# width = 16
# model = WanDiscreteVideoTransformer(
# codebook_size=codebook_size,
# vocab_size=vocab_size,
# num_frames=num_frames,
# height=height,
# width=width,
# # shrink Wan backbone for the unit test
# in_dim=32,
# dim=64,
# ffn_dim=128,
# freq_dim=32,
# text_dim=64,
# out_dim=32,
# num_heads=4,
# num_layers=2,
# ).to(device)
# model.eval()
# batch_size = 2
# # pseudo-video tokens from 2D VQ-VAE on frames: [B, F, H, W]
# tokens = torch.randint(
# low=0,
# high=codebook_size,
# size=(batch_size, num_frames, height, width),
# dtype=torch.long,
# device=device,
# )
# # text: [B, L, C_text]
# text_seq_len = 8
# encoder_hidden_states = torch.randn(
# batch_size, text_seq_len, model.backbone.text_dim, device=device
# )
# # timesteps: [B]
# timesteps = torch.randint(
# low=0, high=1000, size=(batch_size,), dtype=torch.long, device=device
# )
# # track memory if CUDA is available
# if device == "cuda":
# torch.cuda.reset_peak_memory_stats()
# mem_before = torch.cuda.memory_allocated()
# else:
# mem_before = 0
# with torch.no_grad():
# logits = model(
# tokens=tokens,
# timesteps=timesteps,
# encoder_hidden_states=encoder_hidden_states,
# y=None,
# )
# if device == "cuda":
# mem_after = torch.cuda.memory_allocated()
# peak_mem = torch.cuda.max_memory_allocated()
# else:
# mem_after = mem_before
# peak_mem = mem_before
# # logits: [B, codebook_size, F, H_out, W_out]
# assert logits.shape[0] == batch_size
# assert logits.shape[1] == codebook_size
# assert logits.shape[2] == num_frames
# # WanModel returns unpatchified videos, so spatial size matches the input grid.
# h_out = height
# w_out = width
# assert logits.shape[3] == h_out
# assert logits.shape[4] == w_out
# # parameter count sanity check (just ensure it's > 0 and finite)
# num_params = sum(p.numel() for p in model.parameters())
# assert num_params > 0
# assert math.isfinite(float(num_params))
# # memory sanity check (on CUDA the forward pass should allocate > 0 bytes)
# if device == "cuda":
# assert peak_mem >= mem_after >= mem_before
# import torch
# from safetensors import safe_open
# # from src.transformer_video import WanDiscreteVideoTransformer
# ckpt_path = "/mnt/Meissonic/model/diffusion_pytorch_model.safetensors"
# # 1) Instantiate with hyper-parameters chosen to match wan2.1 (a commonly used configuration is given here; it must agree with the ckpt)
# model = WanDiscreteVideoTransformer(
# codebook_size=128, # discrete-side setting, user-defined
# vocab_size=129,
# num_frames=2,
# height=16,
# width=16,
# # Wan backbone hyper-parameters must match the ckpt exactly
# model_type="t2v",
# patch_size=(1, 2, 2),
# in_dim=16,
# dim=1536,
# ffn_dim=8960,
# freq_dim=256,
# text_dim=4096,
# out_dim=16,
# num_heads=12,
# num_layers=30,
# window_size=(-1, -1),
# qk_norm=True,
# cross_attn_norm=True,
# eps=1e-6,
# )
# # 2) Read the safetensors file
# state_dict = {}
# with safe_open(ckpt_path, framework="pt", device="cpu") as f:
# for k in f.keys():
# state_dict[k] = f.get_tensor(k)
# # 3) Try loading into the backbone (token_embedding/logits_head are left untouched)
# missing, unexpected = model.backbone.load_state_dict(state_dict, strict=False)
# print("Missing keys:", missing[:50], "... total", len(missing))
# print("Unexpected keys:", unexpected[:50], "... total", len(unexpected))
# print("Backbone params (M):", sum(p.numel() for p in model.backbone.parameters()) / 1e6)
# print("Params (M):", sum(p.numel() for p in model.parameters()) / 1e6)
# # if __name__ == '__main__':
# # # test_wan_discrete_video_transformer_forward_and_shapes()
# # print('WanDiscreteVideoTransformer forward pass test: PASSED')