"""
WanTransformer3DModel with camera control adapter for Wan2.1-Fun-V1.1-1.3B-Control-Camera.

The camera adapter processes 6-channel Plucker ray embeddings (temporally packed to 24 channels)
and additively injects them into the patch-embedded latents. Architecture matches VideoX-Fun's
SimpleAdapter exactly.

Usage:
    from modeling_wan_camera import WanCameraControlTransformer3DModel

    model = WanCameraControlTransformer3DModel.from_pretrained("path/to/transformer")
"""

import math
from typing import Any

import torch
import torch.nn as nn
from diffusers import WanTransformer3DModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import apply_lora_scale


class ResidualBlock(nn.Module):
    """Two 3x3 same-padding convolutions with a ReLU in between and an identity skip connection.

    The block preserves both the channel count (``dim``) and the spatial size of its input.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        # The in-place ReLU only overwrites conv1's freshly allocated output, never `x` itself.
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        out += residual
        return out


class CameraControlAdapter(nn.Module):
    """
    Processes per-frame Plucker ray embeddings into features matching the transformer's
    patch-embedded latent shape.

    Pipeline: PixelUnshuffle(downscale_factor) -> Conv2d(kernel_size, stride) -> ResidualBlock(s)

    Input:  [B, C_in, F, H_pixel, W_pixel]  (pixel-resolution camera embeddings; C_in is 24 by default)
    Output: [B, out_channels, F, H', W']     (with H' = floor(H_pixel / downscale_factor / stride_h) when
                                              kernel_size == stride, as configured by the transformer below)

    Args:
        in_channels: Channels of the incoming camera embedding (C_in).
        out_channels: Channels produced by the adapter; must equal the transformer's inner dim.
        kernel_size: Kernel size of the Conv2d applied after pixel-unshuffling.
        stride: Stride of that Conv2d.
        downscale_factor: PixelUnshuffle factor; H_pixel and W_pixel must be divisible by it.
        num_residual_blocks: Number of ResidualBlocks applied after the Conv2d.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, downscale_factor=8, num_residual_blocks=1):
        super().__init__()
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)
        # PixelUnshuffle moves each downscale_factor x downscale_factor spatial tile into channels.
        self.conv = nn.Conv2d(
            in_channels * downscale_factor * downscale_factor,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
        )
        self.residual_blocks = nn.Sequential(*[ResidualBlock(out_channels) for _ in range(num_residual_blocks)])

    def forward(self, x):
        bs, c, f, h, w = x.size()
        # Fold frames into the batch so every frame runs through the 2D stack independently.
        x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
        x = self.pixel_unshuffle(x)
        x = self.conv(x)
        x = self.residual_blocks(x)
        # Unfold frames again: [B*F, C', H', W'] -> [B, F, C', H', W'] -> [B, C', F, H', W'].
        x = x.view(bs, f, x.size(1), x.size(2), x.size(3))
        x = x.permute(0, 2, 1, 3, 4)
        return x


class WanCameraControlTransformer3DModel(WanTransformer3DModel):
    """
    WanTransformer3DModel + camera control adapter.

    The adapter output is added to the patch-embedded latents before the transformer blocks,
    matching VideoX-Fun's y_camera injection point.

    Extra config params vs base WanTransformer3DModel:
        control_adapter_in_channels (int): Input channels for camera embeddings (default 24)
        control_adapter_downscale_factor (int): PixelUnshuffle factor (default 8)
        control_adapter_num_residual_blocks (int): Number of residual blocks (default 1)

    Extra forward param:
        control_camera_video (Tensor | None): [B, 24, F, H_px, W_px] Plucker ray embeddings, where
            H_px = H_latent * control_adapter_downscale_factor (likewise for W_px) and F equals the
            number of patch-embedded latent frames.
    """

    @register_to_config
    def __init__(
        self,
        # Base WanTransformer3DModel params (must be explicit for @register_to_config)
        patch_size: tuple[int, ...] = (1, 2, 2),
        num_attention_heads: int = 12,
        attention_head_dim: int = 128,
        in_channels: int = 32,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 8960,
        num_layers: int = 30,
        cross_attn_norm: bool = True,
        qk_norm: str | None = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: int | None = 1280,
        added_kv_proj_dim: int | None = 1536,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: int | None = None,
        # Camera adapter params
        control_adapter_in_channels: int = 24,
        control_adapter_downscale_factor: int = 8,
        control_adapter_num_residual_blocks: int = 1,
    ):
        super().__init__(
            patch_size=patch_size,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            in_channels=in_channels,
            out_channels=out_channels,
            text_dim=text_dim,
            freq_dim=freq_dim,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            cross_attn_norm=cross_attn_norm,
            qk_norm=qk_norm,
            eps=eps,
            image_dim=image_dim,
            added_kv_proj_dim=added_kv_proj_dim,
            rope_max_seq_len=rope_max_seq_len,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        inner_dim = num_attention_heads * attention_head_dim
        ps = patch_size if isinstance(patch_size, (list, tuple)) else [patch_size] * 3
        # kernel_size == stride == spatial patch size, so the adapter's conv patchifies exactly
        # like the spatial part of the latent patch embedding and the outputs line up.
        self.control_adapter = CameraControlAdapter(
            in_channels=control_adapter_in_channels,
            out_channels=inner_dim,
            kernel_size=tuple(ps[1:]),
            stride=tuple(ps[1:]),
            downscale_factor=control_adapter_downscale_factor,
            num_residual_blocks=control_adapter_num_residual_blocks,
        )

    def _camera_features(self, control_camera_video: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the camera adapter and verify its output can be added to the patch-embedded latents.

        Args:
            control_camera_video: Camera embeddings, expected as [B, C_in, F, H_px, W_px].
            hidden_states: Patch-embedded latents (5D, before flattening) that the features are added to.

        Returns:
            Adapter features in the dtype/device of ``hidden_states``.

        Raises:
            ValueError: If ``control_camera_video`` is not 5D, or if the adapter output does not match
                ``hidden_states`` in channels/frames/height/width, or its batch is neither equal to the
                latent batch nor 1. Without this check, size-1 dims would be silently broadcast, and a
                larger camera batch would corrupt the final unpatchify reshape.
        """
        if control_camera_video.ndim != 5:
            raise ValueError(
                "`control_camera_video` must be a 5D tensor [B, C, F, H, W], "
                f"got shape {tuple(control_camera_video.shape)}."
            )

        # Run the adapter in the same dtype/device as the latents it is added to. Otherwise e.g. float32
        # CPU embeddings fed to a bf16 CUDA model make the adapter's Conv2d raise a type mismatch.
        camera_features = self.control_adapter(
            control_camera_video.to(device=hidden_states.device, dtype=hidden_states.dtype)
        )

        batch_ok = camera_features.shape[0] in (1, hidden_states.shape[0])
        if camera_features.shape[1:] != hidden_states.shape[1:] or not batch_ok:
            _, _, frames, post_h, post_w = hidden_states.shape
            _, p_h, p_w = self.config.patch_size
            ds = self.config.control_adapter_downscale_factor
            raise ValueError(
                "`control_camera_video` does not match the latents: the camera adapter produced features of shape "
                f"{tuple(camera_features.shape)}, but the patch-embedded latents have shape "
                f"{tuple(hidden_states.shape)}. Expected `control_camera_video` of shape "
                f"[B, {self.config.control_adapter_in_channels}, {frames}, {post_h * p_h * ds}, {post_w * p_w * ds}] "
                f"with B equal to {hidden_states.shape[0]} or 1."
            )
        return camera_features

    @apply_lora_scale("attention_kwargs")
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: torch.Tensor | None = None,
        control_camera_video: torch.Tensor | None = None,
        return_dict: bool = True,
        attention_kwargs: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor] | Transformer2DModelOutput:
        """Denoise ``hidden_states``, optionally conditioned on camera embeddings.

        Mirrors ``WanTransformer3DModel.forward`` and adds the camera adapter output right after
        patch embedding.

        Args:
            hidden_states: Latent video [B, C, F, H, W].
            timestep: Timesteps, either [B] or [B, seq_len].
            encoder_hidden_states: Text embeddings.
            encoder_hidden_states_image: Optional image embeddings, prepended to the text context.
            control_camera_video: Optional camera embeddings; see the class docstring for the expected shape.
            return_dict: Return a ``Transformer2DModelOutput`` instead of a 1-tuple.
            attention_kwargs: Extra attention kwargs (e.g. LoRA ``scale``, consumed by ``apply_lora_scale``).

        Returns:
            ``Transformer2DModelOutput(sample=output)`` or ``(output,)``.

        Raises:
            ValueError: If ``control_camera_video`` does not match the latents' shape.
        """
        batch_size, _, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        rotary_emb = self.rope(hidden_states)

        hidden_states = self.patch_embedding(hidden_states)

        # Camera adapter: additive injection after patch embedding (VideoX-Fun's y_camera injection point).
        if control_camera_video is not None:
            hidden_states = hidden_states + self._camera_features(control_camera_video, hidden_states)

        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # A 2D timestep carries one value per token.
        if timestep.ndim == 2:
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        if ts_seq_len is not None:
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
        else:
            for block in self.blocks:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

        # Output norm modulation: per-token (3D temb) or per-sample (2D temb) shift/scale.
        if temb.ndim == 3:
            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
            shift = shift.squeeze(2)
            scale = scale.squeeze(2)
        else:
            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

        # temb may live on a different device than hidden_states when the model is split across devices.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify: tokens -> [B, C_out, F, H, W].
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)