import math
import warnings
from typing import Optional

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

# Global debug flag - set to False to disable debug prints
DEBUG_TRANSFORMER = False

# Optional flash-attention backends; availability flags gate which kernel is used.
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

# Single public-API declaration (the original assigned __all__ twice, and the
# second assignment silently clobbered the first).
__all__ = [
    'flash_attention',
    'attention',
    'WanModel',
    'WanDiscreteVideoTransformer',
]


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """Variable-length flash attention over padded batches.

    q: [B, Lq, Nq, C1].
    k: [B, Lk, Nk, C1].
    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens: [B], valid query lengths per sample (None = all of Lq).
    k_lens: [B], valid key lengths per sample (None = all of Lk).
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
    causal: bool. Whether to apply causal attention mask.
    window_size: (left, right). If not (-1, -1), apply sliding window local attention.
    deterministic: bool. If True, slightly slower and uses more memory.
    dtype: torch.dtype. Applied when dtype of q/k/v is not float16/bfloat16.
    version: None, 2, or 3 — preferred flash-attention version (falls back to 2).

    Returns a [B, Lq, Nq, C2] tensor in the original dtype of ``q``.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    # flash-attn kernels require CUDA and head dims <= 256
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        # cast to a half-precision dtype only when necessary
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query: flatten [B, Lq, ...] into a packed varlen layout
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [lq] * b, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value the same way
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [lk] * b, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            'Flash attention 3 is not available, use flash attention 2 instead.'
        )

    # apply attention
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        x = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic)[0].unflatten(0, (b, lq))
    else:
        assert FLASH_ATTN_2_AVAILABLE
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic).unflatten(0, (b, lq))

    # output: restore the caller's dtype
    return x.type(out_dtype)


def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    """Dispatch to flash attention when available, else torch SDPA.

    Same tensor layout as :func:`flash_attention` ([B, L, N, C]); the SDPA
    fallback ignores ``q_lens``/``k_lens`` (no padding mask).
    """
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )
    else:
        if q_lens is not None or k_lens is not None:
            warnings.warn(
                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
            )
        attn_mask = None

        # FIX: remember the input dtype so this fallback matches the flash
        # path, which returns its output in the caller's original dtype.
        out_dtype = q.dtype
        q = q.transpose(1, 2).to(dtype)
        k = k.transpose(1, 2).to(dtype)
        v = v.transpose(1, 2).to(dtype)

        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)

        out = out.transpose(1, 2).contiguous()
        return out.type(out_dtype)


def sinusoidal_embedding_1d(dim, position):
    """Standard 1D sinusoidal embedding: [cos | sin] halves, float64 precision.

    position: 1D tensor of (possibly fractional) positions.
    Returns a [len(position), dim] float64 tensor.
    """
    # preprocess
    assert dim % 2 == 0
    half = dim // 2

    # Compute in float64 for precision; keep tensors on the input's device.
    device = position.device
    position = position.to(torch.float64)

    # calculation
    arange_tensor = torch.arange(half, dtype=torch.float64, device=device)
    sinusoid = torch.outer(
        position, torch.pow(10000, -arange_tensor.div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x


@torch.amp.autocast('cuda', enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    """Precompute complex RoPE rotation factors of shape [max_seq_len, dim/2]."""
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta,
                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs


@torch.amp.autocast('cuda', enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """Apply 3D (frame/height/width) rotary embeddings per sample.

    x: [B, L, N, C]; grid_sizes: [B, 3] of (F, H, W); freqs: complex factors
    from :func:`rope_params`, split across the three axes.
    Positions beyond F*H*W (padding) are passed through unrotated.
    """
    n, c = x.size(2), x.size(3) // 2

    # Save original dtype to restore it later
    original_dtype = x.dtype

    # split freqs across (frame, height, width) channel groups
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers: view channel pairs as complex numbers
        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
            seq_len, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ],
                            dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        # Convert back to original dtype before concatenating
        x_i = x_i.to(dtype=original_dtype)

        # Handle the remaining (padding) part of the sequence, if any
        x_remaining = x[i, seq_len:]
        if x_remaining.numel() > 0:
            x_i = torch.cat([x_i, x_remaining])
        else:
            x_i = x_i

        # append to collection
        output.append(x_i)

    # Stack and ensure dtype matches original input
    return torch.stack(output).to(dtype=original_dtype)


class WanRMSNorm(nn.Module):

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        # Normalize in float32 for stability; scale in the input dtype.
        return self._norm(x.float()).type_as(x) * self.weight.type_as(x)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)


class WanLayerNorm(nn.LayerNorm):

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        # Convert to float32 for numerical stability, ensuring weights match
        # the float32 compute dtype, then cast the result back.
        original_dtype = x.dtype
        x_float = x.float()

        if self.elementwise_affine:
            weight_float = self.weight.float() if self.weight is not None else None
            bias_float = self.bias.float() if self.bias is not None else None
            # Use torch.nn.functional.layer_norm directly with converted weights
            result = torch.nn.functional.layer_norm(
                x_float, self.normalized_shape, weight_float, bias_float,
                self.eps)
        else:
            result = super().forward(x_float)

        return result.to(dtype=original_dtype)


class WanSelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs):
        r"""
        Args:
            x(Tensor): Shape [B, L, num_heads, C / num_heads]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value function
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        # Save input dtype to ensure output matches
        input_dtype = x.dtype

        x = flash_attention(
            q=rope_apply(q, grid_sizes, freqs),
            k=rope_apply(k, grid_sizes, freqs),
            v=v,
            k_lens=seq_lens,
            window_size=self.window_size)

        # Ensure output dtype matches input dtype (in case rope_apply or
        # flash_attention changed it)
        x = x.to(dtype=input_dtype)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x


class WanCrossAttention(WanSelfAttention):

    def forward(self, x, context, context_lens):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
        """
        b, n, d = x.size(0), self.num_heads, self.head_dim

        # Save input dtype to ensure output matches
        input_dtype = x.dtype

        # compute query, key, value
        q = self.norm_q(self.q(x)).view(b, -1, n, d)
        k = self.norm_k(self.k(context)).view(b, -1, n, d)
        v = self.v(context).view(b, -1, n, d)

        # compute attention
        x = flash_attention(q, k, v, k_lens=context_lens)

        # Ensure output dtype matches input dtype
        x = x.to(dtype=input_dtype)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x


class WanAttentionBlock(nn.Module):

    def __init__(self,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                          eps)
        self.norm3 = WanLayerNorm(
            dim, eps,
            elementwise_affine=True) if cross_attn_norm else nn.Identity()
        self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
                                            eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
    ):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, L1, 6, C]
            seq_lens(Tensor): Shape [B], length of each sequence in batch
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        # Convert e to float32 for modulation computation (modulation expects float32)
        e_float32 = e.to(dtype=torch.float32) if e.dtype != torch.float32 else e

        with torch.amp.autocast('cuda', dtype=torch.float32):
            e = (self.modulation.unsqueeze(0) + e_float32).chunk(6, dim=2)
        assert e[0].dtype == torch.float32

        # self-attention
        # Ensure input dtype matches model weights (convert e to match x's dtype)
        x_dtype = x.dtype
        e_0 = e[0].squeeze(2).to(dtype=x_dtype)
        e_1 = e[1].squeeze(2).to(dtype=x_dtype)
        e_2 = e[2].squeeze(2).to(dtype=x_dtype)

        attn_input = self.norm1(x) * (1 + e_1) + e_0
        y = self.self_attn(attn_input, seq_lens, grid_sizes, freqs)

        # Ensure dtype consistency: y and e_2 should match x's dtype
        x = x + (y * e_2).to(dtype=x_dtype)

        # cross-attention & ffn function
        def cross_attn_ffn(x, context, context_lens, e):
            x = x + self.cross_attn(self.norm3(x), context, context_lens)
            # Ensure dtype consistency for FFN input
            x_dtype = x.dtype
            e_3 = e[3].squeeze(2).to(dtype=x_dtype)
            e_4 = e[4].squeeze(2).to(dtype=x_dtype)
            e_5 = e[5].squeeze(2).to(dtype=x_dtype)
            ffn_input = self.norm2(x) * (1 + e_4) + e_3
            y = self.ffn(ffn_input)
            # Ensure dtype consistency: y and e_5 should match x's dtype
            x = x + (y * e_5).to(dtype=x_dtype)
            return x

        x = cross_attn_ffn(x, context, context_lens, e)
        return x


class Head(nn.Module):

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers
        out_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, L1, C]
        """
        # Convert e to float32 for modulation computation (modulation expects float32)
        e_float32 = e.to(dtype=torch.float32) if e.dtype != torch.float32 else e

        with torch.amp.autocast('cuda', dtype=torch.float32):
            e = (self.modulation.unsqueeze(0) +
                 e_float32.unsqueeze(2)).chunk(2, dim=2)

        # Ensure dtype consistency: convert e to match x's dtype
        x_dtype = x.dtype
        e_0 = e[0].squeeze(2).to(dtype=x_dtype)
        e_1 = e[1].squeeze(2).to(dtype=x_dtype)

        head_input = self.norm(x) * (1 + e_1) + e_0
        x = self.head(head_input)
        return x


class WanModel(ModelMixin, ConfigMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    ignore_for_config = [
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    ]
    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """
        super().__init__()

        assert model_type in ['t2v', 'i2v', 'ti2v', 's2v']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        self.blocks = nn.ModuleList([
            WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
                              cross_attn_norm, eps) for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ],
                               dim=1)

        # initialize weights
        self.init_weights()

    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        y=None,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B]
            context (List[Tensor]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x

        Returns:
            List[Tensor]:
                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
        """
        if self.model_type == 'i2v':
            assert y is not None
        # params
        device = self.patch_embedding.weight.device
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        # Ensure input dtype matches patch_embedding weight dtype
        patch_weight_dtype = self.patch_embedding.weight.dtype
        x = [
            self.patch_embedding(u.unsqueeze(0).to(dtype=patch_weight_dtype))
            for u in x
        ]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        if t.dim() == 1:
            # FIX: broadcasting a 1-D [B] tensor directly with
            # t.expand(t.size(0), seq_len) is invalid (expand aligns trailing
            # dims, so it required B == seq_len). Add the sequence dim first.
            t = t.unsqueeze(1).expand(t.size(0), seq_len)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            bt = t.size(0)
            t = t.flatten()
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim,
                                        t).unflatten(0, (bt, seq_len)).float())
            e0 = self.time_projection(e).unflatten(2, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # Keep e and e0 as float32 for modulation computation
        # They will be converted to x.dtype inside WanAttentionBlock.forward and Head.forward when needed

        # context
        context_lens = None
        # Ensure context input dtype matches text_embedding weight dtype
        text_weight_dtype = self.text_embedding[0].weight.dtype
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]).to(dtype=text_weight_dtype))

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens)

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                    shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """
        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)


class WanDiscreteVideoTransformer(ModelMixin, ConfigMixin):
    r"""
    Wrapper around :class:`WanModel` that makes it usable as a **discrete video
    diffusion backbone**.

    The goals of this wrapper are:

    - keep the inner :class:`WanModel` architecture and parameter names intact
      so that Wan-1.3B weights can later be loaded directly into ``self.backbone``;
    - expose a simpler interface that takes **discrete codebook indices** (from
      a 2D VQ-VAE on pseudo-video) and returns **logits over the codebook** for
      each spatio-temporal position.

    Notes
    -----
    - This class does **not** try to be drop-in compatible with Meissonic's 2D
      ``Transformer2DModel``. It is a parallel, video-oriented path that still
      follows the same *discrete diffusion* principle: predict per-token logits
      given masked tokens + text.
    - Pseudo-video is represented as a 4D integer tensor ``[B, F, H, W]`` of
      codebook indices. How to get these tokens from the current 2D VQ-VAE
      (e.g. per-frame encoding & stacking) is left to the higher-level
      training / pipeline code.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        # discrete codebook settings
        codebook_size: int,
        vocab_size: int,
        # video layout
        num_frames: int,
        height: int,
        width: int,
        # Wan backbone hyper-parameters (mirrors WanModel.__init__)
        model_type: str = 't2v',
        patch_size: tuple = (1, 2, 2),
        text_len: int = 512,
        in_dim: int = 16,
        dim: int = 2048,
        ffn_dim: int = 8192,
        freq_dim: int = 256,
        text_dim: int = 4096,
        out_dim: int = 16,
        num_heads: int = 16,
        num_layers: int = 32,
        window_size: tuple = (-1, -1),
        qk_norm: bool = True,
        cross_attn_norm: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()

        # save a minimal set of attributes useful for downstream tooling
        self.codebook_size = codebook_size
        self.vocab_size = vocab_size
        self.num_frames = num_frames
        self.height = height
        self.width = width

        # 1) backbone: keep WanModel intact for future weight loading
        self.backbone = WanModel(
            model_type=model_type,
            patch_size=patch_size,
            text_len=text_len,
            in_dim=in_dim,
            dim=dim,
            ffn_dim=ffn_dim,
            freq_dim=freq_dim,
            text_dim=text_dim,
            out_dim=out_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            window_size=window_size,
            qk_norm=qk_norm,
            cross_attn_norm=cross_attn_norm,
            eps=eps,
        )

        # 2) discrete token embedding -> continuous video volume
        #
        # Input: tokens [B, F, H, W] with values in [0, vocab_size) where:
        #   - [0, codebook_size-1] = actual Cosmos codes (direct mapping, no shift)
        #   - codebook_size = mask_token_id (reserved for masking)
        # Output: list of length B with tensors [in_dim, F, H, W]
        #
        # We keep this outside the backbone so that loading official Wan 1.3B
        # weights into self.backbone will still work without clashes.
        # Note: vocab_size = codebook_size + 1 to accommodate mask_token_id = codebook_size
        self.token_embedding = nn.Embedding(vocab_size, in_dim)

        # 3) projection from continuous video output -> logits over codebook
        #
        # Backbone output: list of B tensors [out_dim, F, H', W']
        # We map it with a 3D 1x1x1 conv to [vocab_size, F, H', W'].
        # Note: vocab_size = codebook_size + 1, where codebook_size is reserved for mask_token_id
        self.logits_head = nn.Conv3d(out_dim, vocab_size, kernel_size=1)

        # Gradient checkpointing support
        self.gradient_checkpointing = False

    def _tokens_to_video(self, tokens: torch.LongTensor) -> list:
        r"""
        Convert discrete tokens ``[B, F, H, W]`` into a list of length ``B``
        where each element is a dense video tensor ``[in_dim, F, H, W]``
        suitable for :class:`WanModel`.

        Note: This method now supports dynamic input dimensions. The
        num_frames, height, width stored in config are used as defaults/for
        seq_len calculation, but inputs can have different dimensions as long
        as they're valid.
        """
        assert tokens.dim() == 4, f"expected [B, F, H, W] tokens, got {tokens.shape}"
        # Dynamic dimensions - no strict dimension checks, WanModel handles variable sizes

        # [B, F, H, W, in_dim]
        x = self.token_embedding(tokens)
        # Ensure dtype matches model's expected dtype (usually bfloat16 for mixed precision)
        token_embedding_dtype = self.token_embedding.weight.dtype
        x = x.to(dtype=token_embedding_dtype)

        # [B, in_dim, F, H, W]
        x = x.permute(0, 4, 1, 2, 3).contiguous()

        # WanModel expects a list of [C_in, F, H, W]
        return [x_i for x_i in x]

    def _text_to_list(self, encoder_hidden_states: torch.Tensor) -> list:
        r"""
        Convert batched text embeddings ``[B, L, C]`` into the list-of-tensors
        format expected by :class:`WanModel`.
        """
        assert encoder_hidden_states.dim() == 3, (
            f"expected encoder_hidden_states [B, L, C], got {encoder_hidden_states.shape}")
        return [e for e in encoder_hidden_states]

    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        """Set gradient checkpointing for the module."""
        self.gradient_checkpointing = enable

    def forward(
        self,
        tokens: torch.LongTensor,
        timesteps: torch.LongTensor,
        encoder_hidden_states: torch.FloatTensor,
        y: Optional[list] = None,
    ) -> torch.FloatTensor:
        r"""
        Forward pass of the **discrete video transformer**.

        Args:
            tokens (`torch.LongTensor` of shape `[B, F, H, W]`):
                Discrete codebook indices (e.g. from a 2D VQ-VAE applied frame-wise).
            timesteps (`torch.LongTensor` of shape `[B]` or `[B, F * H * W]`):
                Diffusion timestep(s), following the same semantics as
                Meissonic's scalar timesteps.
            encoder_hidden_states (`torch.FloatTensor` of shape `[B, L, C_text]`):
                Text embeddings (e.g. from CLIP). Each sample corresponds to one video.
            y (`Optional[list]`):
                Optional conditional video list passed to the underlying
                :class:`WanModel` for i2v / ti2v / s2v variants. For now this
                is surfaced as a raw passthrough and can be left as ``None``
                for pure text-to-video.

        Returns:
            `torch.FloatTensor`:
                Logits over the codebook of shape
                `[B, codebook_size, F, H_out, W_out]`, where `(H_out, W_out)`
                depend on the Wan patch configuration. For the default
                `patch_size=(1, 2, 2)` and input ``H=W=height``, we have
                ``H_out = height // 2`` and ``W_out = width // 2``.
        """
        device = tokens.device

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] Input: tokens.shape={tokens.shape}, encoder_hidden_states.shape={encoder_hidden_states.shape}, timesteps.shape={timesteps.shape}")

        x_list = self._tokens_to_video(tokens)
        context_list = self._text_to_list(encoder_hidden_states)

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] After conversion: len(x_list)={len(x_list)}, len(context_list)={len(context_list)}")
            if len(x_list) > 0:
                print(f"[DEBUG-transformer] x_list[0].shape={x_list[0].shape}")
            if len(context_list) > 0:
                print(f"[DEBUG-transformer] context_list[0].shape={context_list[0].shape}")

        # Calculate seq_len from actual input dimensions (supports dynamic sizes)
        # tokens: [B, F, H, W] -> after patchification: seq_len = F * (H/p_h) * (W/p_w)
        _, f_in, h_in, w_in = tokens.shape
        h_patch = h_in // self.backbone.patch_size[1]
        w_patch = w_in // self.backbone.patch_size[2]
        seq_len = f_in * h_patch * w_patch

        # Prepare timesteps in the exact shape WanModel.forward expects:
        # always hand it a [B, seq_len] tensor so the backbone's 1-D branch
        # is never exercised.
        if timesteps.dim() == 1:
            # [B] -> [B, 1] -> [B, seq_len] (broadcast along sequence)
            t_model = timesteps.to(device).unsqueeze(1).expand(-1, seq_len)
        elif timesteps.dim() == 2:
            assert timesteps.size(1) == seq_len, (
                f"Expected timesteps second dim == seq_len ({seq_len}), "
                f"but got {timesteps.size(1)}"
            )
            t_model = timesteps.to(device)
        else:
            raise ValueError(
                f"Unsupported timesteps shape {timesteps.shape}; "
                "expected [B] or [B, seq_len]"
            )

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] t_model.shape={t_model.shape}")

        # WanModel.forward expects:
        #   x: List[Tensor [C_in, F, H, W]]
        #   t: Tensor [B] or [B, seq_len]
        #   context: List[Tensor [L, C_text]]
        #   seq_len: int
        #   y: Optional[List[Tensor]]
        if self.training and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # Unpack inputs: x_list, t, context_list, seq_len, y
                    x_in, t_in, context_in, seq_len_in, y_in = inputs
                    return module(x=x_in, t=t_in, context=context_in, seq_len=seq_len_in, y=y_in)
                return custom_forward

            # Use gradient checkpointing for the backbone
            ckpt_kwargs = {"use_reentrant": False}
            out_list = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.backbone),
                x_list,
                t_model,
                context_list,
                seq_len,
                y,
                **ckpt_kwargs,
            )
        else:
            out_list = self.backbone(
                x=x_list,
                t=t_model,
                context=context_list,
                seq_len=seq_len,
                y=y,
            )

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] After backbone: len(out_list)={len(out_list)}")
            if len(out_list) > 0:
                print(f"[DEBUG-transformer] out_list[0].shape={out_list[0].shape}")

        # out_list: length B, each [C_out, F, H_out, W_out]
        vids = torch.stack(out_list, dim=0)  # [B, C_out, F, H_out, W_out]

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] After stack: vids.shape={vids.shape}")

        # Ensure vids dtype matches logits_head weight dtype
        vids = vids.to(dtype=self.logits_head.weight.dtype)
        logits = self.logits_head(vids)  # [B, vocab_size, F, H_out, W_out] where vocab_size = codebook_size + 1

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] Final logits.shape={logits.shape}")

        return logits

# NOTE: the commented-out smoke test (shape/memory check) and the
# safetensors weight-loading scratch code that used to live here were dead,
# fully commented-out code and have been removed. Recover them from version
# control if needed.