# Hosting-page header (not part of the module): file uploaded by the-sweater-cat
# as modeling_wan_camera.py with huggingface_hub, commit 795c1a4 (verified).
"""
WanTransformer3DModel with a camera control adapter for Wan2.1-Fun-V1.1-1.3B-Control-Camera.

The camera adapter processes 6-channel Plucker ray embeddings (temporally packed
to 24 channels) and additively injects them into the patch-embedded latents.
The architecture matches VideoX-Fun's SimpleAdapter exactly.

Usage:
    from modeling_wan_camera import WanCameraControlTransformer3DModel

    model = WanCameraControlTransformer3DModel.from_pretrained("path/to/transformer")
"""
import math
from typing import Any
import torch
import torch.nn as nn
from diffusers import WanTransformer3DModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import apply_lora_scale
class ResidualBlock(nn.Module):
    """
    Two 3x3 convolutions with a ReLU in between and an identity skip connection.

    Both convolutions map `dim` channels to `dim` channels with padding 1, so the
    output has the same shape as the input.
    """

    def __init__(self, dim):
        super().__init__()
        # Shared settings: 3x3 kernels with padding 1 keep the spatial size.
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        # Submodule names and creation order are kept so checkpoints still load.
        self.conv1 = nn.Conv2d(dim, dim, **conv_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(dim, dim, **conv_kwargs)

    def forward(self, x):
        """Return `conv2(relu(conv1(x))) + x`; `x` itself is not modified."""
        hidden = self.conv1(x)
        # The in-place ReLU overwrites conv1's fresh output, never the input.
        hidden = self.relu(hidden)
        hidden = self.conv2(hidden)
        # Skip connection, accumulated in place on conv2's output.
        hidden += x
        return hidden
class CameraControlAdapter(nn.Module):
    """
    Turn per-frame camera embeddings into feature maps, one frame at a time.

    Pipeline: PixelUnshuffle(downscale_factor) -> Conv2d -> ResidualBlock(s)

    Input:  5-D tensor [B, C, F, H, W] with C == in_channels.
    Output: [B, out_channels, F, H', W'], where H' and W' are the spatial size
            left after the pixel-unshuffle and the unpadded strided convolution.

    The caller is expected to pick kernel_size / stride / downscale_factor so that
    the output lines up with the tensor it is added to.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 downscale_factor=8, num_residual_blocks=1):
        super().__init__()
        # PixelUnshuffle folds each (factor x factor) spatial patch into channels.
        unshuffled_channels = in_channels * downscale_factor * downscale_factor
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)
        self.conv = nn.Conv2d(
            unshuffled_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
        )
        blocks = []
        for _ in range(num_residual_blocks):
            blocks.append(ResidualBlock(out_channels))
        self.residual_blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the 2-D pipeline to every frame and restore the [B, C', F, H', W'] layout."""
        batch, channels, frames, height, width = x.shape
        # Fold frames into the batch axis so the 2-D layers see [B*F, C, H, W].
        per_frame = x.transpose(1, 2).contiguous().view(batch * frames, channels, height, width)
        features = self.residual_blocks(self.conv(self.pixel_unshuffle(per_frame)))
        # Split the folded axis back into (batch, frames), then move channels ahead of frames.
        features = features.view(batch, frames, *features.shape[1:])
        return features.transpose(1, 2)
class WanCameraControlTransformer3DModel(WanTransformer3DModel):
    """
    WanTransformer3DModel + camera control adapter.

    The adapter output is added to the patch-embedded latents before the
    transformer blocks, matching VideoX-Fun's y_camera injection point.

    Extra config params vs base WanTransformer3DModel:
        control_adapter_in_channels (int): Input channels for camera embeddings (default 24)
        control_adapter_downscale_factor (int): PixelUnshuffle factor (default 8)
        control_adapter_num_residual_blocks (int): Number of residual blocks (default 1)

    Extra forward param:
        control_camera_video (Tensor | None): [B, 24, F, H_px, W_px] Plucker ray embeddings.
            The tensor is moved to the device and dtype of the patch-embedded latents
            before it enters the adapter.
    """

    @register_to_config
    def __init__(
        self,
        # Base WanTransformer3DModel params (must be explicit for @register_to_config)
        patch_size: tuple[int, ...] = (1, 2, 2),
        num_attention_heads: int = 12,
        attention_head_dim: int = 128,
        in_channels: int = 32,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 8960,
        num_layers: int = 30,
        cross_attn_norm: bool = True,
        qk_norm: str | None = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: int | None = 1280,
        added_kv_proj_dim: int | None = 1536,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: int | None = None,
        # Camera adapter params
        control_adapter_in_channels: int = 24,
        control_adapter_downscale_factor: int = 8,
        control_adapter_num_residual_blocks: int = 1,
    ):
        """
        Build the base transformer, then attach the camera control adapter.

        All base parameters are forwarded unchanged to `WanTransformer3DModel`.
        The adapter's convolution uses the spatial part of `patch_size` as both
        kernel size and stride, and emits `num_attention_heads * attention_head_dim`
        channels so its output can be added to the patch embedding.
        """
        super().__init__(
            patch_size=patch_size,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            in_channels=in_channels,
            out_channels=out_channels,
            text_dim=text_dim,
            freq_dim=freq_dim,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            cross_attn_norm=cross_attn_norm,
            qk_norm=qk_norm,
            eps=eps,
            image_dim=image_dim,
            added_kv_proj_dim=added_kv_proj_dim,
            rope_max_seq_len=rope_max_seq_len,
            pos_embed_seq_len=pos_embed_seq_len,
        )
        inner_dim = num_attention_heads * attention_head_dim
        # Accept a scalar patch size by repeating it for (t, h, w).
        ps = patch_size if isinstance(patch_size, (list, tuple)) else [patch_size] * 3
        self.control_adapter = CameraControlAdapter(
            in_channels=control_adapter_in_channels,
            out_channels=inner_dim,
            # Only the spatial patch dims apply: the adapter works frame by frame.
            kernel_size=tuple(ps[1:]),
            stride=tuple(ps[1:]),
            downscale_factor=control_adapter_downscale_factor,
            num_residual_blocks=control_adapter_num_residual_blocks,
        )

    @apply_lora_scale("attention_kwargs")
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: torch.Tensor | None = None,
        control_camera_video: torch.Tensor | None = None,
        return_dict: bool = True,
        attention_kwargs: dict[str, Any] | None = None,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """
        Run the transformer, optionally conditioning on camera embeddings.

        Args:
            hidden_states: 5-D latents [B, C, F, H, W].
            timestep: 1-D or 2-D timesteps. A 2-D tensor is flattened and its second
                dimension is passed on as `timestep_seq_len`.
            encoder_hidden_states: Text conditioning handed to the condition embedder.
            encoder_hidden_states_image: Optional image conditioning. When the embedder
                returns it, it is concatenated in front of the text states along dim 1.
            control_camera_video: Optional camera embeddings for the control adapter.
                The adapter output is added to the patch-embedded latents, so it must
                be broadcast-compatible with them. When `None`, no camera conditioning
                is applied.
            return_dict: Return a `Transformer2DModelOutput` when True, otherwise a
                1-tuple.
            attention_kwargs: Consumed by the `apply_lora_scale` decorator; not read
                in this method body.

        Returns:
            `Transformer2DModelOutput(sample=output)` or `(output,)`, where `output`
            has shape [B, C_out, F, H, W] after un-patchifying.

        NOTE(review): `control_camera_video` sits before `return_dict` in the signature,
        so `return_dict` / `attention_kwargs` should be passed by keyword. Confirm
        against callers written for the base class.
        """
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        # Rotary embeddings are computed from the un-patchified latents.
        rotary_emb = self.rope(hidden_states)
        hidden_states = self.patch_embedding(hidden_states)

        # Camera adapter: additive injection after patch embedding
        if control_camera_video is not None:
            # Camera embeddings may arrive in a different dtype or on a different device
            # than the model (e.g. float32 rays with half-precision weights). Align them
            # with the patch-embedded latents so the adapter conv and the addition do
            # not fail on a mismatch. This is a no-op when they already match.
            control_camera_video = control_camera_video.to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )
            hidden_states = hidden_states + self.control_adapter(control_camera_video)

        # [B, D, F', H', W'] -> [B, F'*H'*W', D] token sequence.
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        if timestep.ndim == 2:
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        # Split the projected timestep embedding into 6 groups. The axis to split
        # depends on whether the timestep carried a sequence dimension.
        if ts_seq_len is not None:
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
        else:
            for block in self.blocks:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

        # Output modulation: the table plus the timestep embedding yields shift and scale.
        if temb.ndim == 3:
            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
            shift = shift.squeeze(2)
            scale = scale.squeeze(2)
        else:
            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

        # shift/scale follow temb's device; move them to where hidden_states lives.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        # Normalise in float32, modulate, then return to the working dtype.
        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        # Un-patchify: [B, F', H', W', p_t, p_h, p_w, C] -> [B, C, F'*p_t, H'*p_h, W'*p_w].
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)