Instructions to use the-sweater-cat/Wan2.1-Fun-V1.1-1.3B-Control-Camera-Diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use the-sweater-cat/Wan2.1-Fun-V1.1-1.3B-Control-Camera-Diffusers with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image, export_to_video

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("the-sweater-cat/Wan2.1-Fun-V1.1-1.3B-Control-Camera-Diffusers", dtype=torch.bfloat16, device_map="cuda")
pipe.to("cuda")

prompt = "A man with short gray hair plays a red electric guitar."
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png"
)

output = pipe(image=image, prompt=prompt).frames[0]
export_to_video(output, "output.mp4")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| WanTransformer3DModel with camera control adapter for Wan2.1-Fun-V1.1-1.3B-Control-Camera. | |
| The camera adapter processes 6-channel Plucker ray embeddings (temporally packed | |
| to 24 channels) and additively injects them into the patch-embedded latents. | |
| Architecture matches VideoX-Fun's SimpleAdapter exactly. | |
| Usage: | |
| from modeling_wan_camera import WanCameraControlTransformer3DModel | |
| model = WanCameraControlTransformer3DModel.from_pretrained("path/to/transformer") | |
| """ | |
| import math | |
| from typing import Any | |
| import torch | |
| import torch.nn as nn | |
| from diffusers import WanTransformer3DModel | |
| from diffusers.configuration_utils import register_to_config | |
| from diffusers.models.modeling_outputs import Transformer2DModelOutput | |
| from diffusers.utils import apply_lora_scale | |
class ResidualBlock(nn.Module):
    """Conv -> ReLU -> Conv with an identity skip connection.

    Both convolutions are 3x3 with ``padding=1`` and map ``dim`` channels to
    ``dim`` channels, so the input can be added back to the result unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        # Submodule names (conv1 / relu / conv2) determine the state-dict keys.
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv2d(dim, dim, **conv_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(dim, dim, **conv_kwargs)

    def forward(self, x):
        """Return ``conv2(relu(conv1(x))) + x``."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        hidden = self.conv2(hidden)
        # In-place add on the freshly computed conv output; ``x`` is not modified.
        hidden += x
        return hidden
class CameraControlAdapter(nn.Module):
    """
    Turn per-frame Plucker ray embeddings into features shaped like the
    transformer's patch-embedded latents.

    Every frame is handled as an independent 2D image:
        PixelUnshuffle(downscale_factor) -> Conv2d -> ResidualBlock(s)

    Input:  [B, C, F, H_pixel, W_pixel] (pixel-resolution camera embeddings,
            C == in_channels, 24 for the shipped checkpoint)
    Output: [B, out_channels, F, H', W'] where H'/W' are the spatial sizes
            left after the pixel-unshuffle and the strided convolution
            (intended to match the patch_embedding output).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 downscale_factor=8, num_residual_blocks=1):
        super().__init__()
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)

        # PixelUnshuffle folds each (factor x factor) spatial tile into channels.
        unshuffled_channels = in_channels * downscale_factor * downscale_factor
        self.conv = nn.Conv2d(
            unshuffled_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
        )

        blocks = [ResidualBlock(out_channels) for _ in range(num_residual_blocks)]
        self.residual_blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the 2D adapter to every frame and restore the 5D layout."""
        batch, channels, frames, height, width = x.shape

        # [B, C, F, H, W] -> [B*F, C, H, W] so the 2D layers see one frame per sample.
        per_frame = x.transpose(1, 2).contiguous().view(batch * frames, channels, height, width)

        feats = self.pixel_unshuffle(per_frame)
        feats = self.conv(feats)
        feats = self.residual_blocks(feats)

        # [B*F, C', H', W'] -> [B, F, C', H', W'] -> [B, C', F, H', W']
        feats = feats.view(batch, frames, *feats.shape[1:])
        return feats.transpose(1, 2)
class WanCameraControlTransformer3DModel(WanTransformer3DModel):
    """
    WanTransformer3DModel + camera control adapter.

    The adapter output is added to patch-embedded latents before the transformer
    blocks, matching VideoX-Fun's y_camera injection point.

    Extra config params vs base WanTransformer3DModel:
        control_adapter_in_channels (int): Input channels for camera embeddings (default 24)
        control_adapter_downscale_factor (int): PixelUnshuffle factor (default 8)
        control_adapter_num_residual_blocks (int): Number of residual blocks (default 1)

    Extra forward param:
        control_camera_video (Tensor | None): [B, 24, F, H_px, W_px] Plucker ray embeddings
    """

    # Registering here (in addition to the base class) records the
    # control_adapter_* arguments in ``self.config`` alongside the base params.
    @register_to_config
    def __init__(
        self,
        # Base WanTransformer3DModel params (must be explicit for @register_to_config)
        patch_size: tuple[int, ...] = (1, 2, 2),
        num_attention_heads: int = 12,
        attention_head_dim: int = 128,
        in_channels: int = 32,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 8960,
        num_layers: int = 30,
        cross_attn_norm: bool = True,
        qk_norm: str | None = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: int | None = 1280,
        added_kv_proj_dim: int | None = 1536,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: int | None = None,
        # Camera adapter params
        control_adapter_in_channels: int = 24,
        control_adapter_downscale_factor: int = 8,
        control_adapter_num_residual_blocks: int = 1,
    ):
        """Build the base Wan transformer, then attach the camera control adapter."""
        super().__init__(
            patch_size=patch_size,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            in_channels=in_channels,
            out_channels=out_channels,
            text_dim=text_dim,
            freq_dim=freq_dim,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            cross_attn_norm=cross_attn_norm,
            qk_norm=qk_norm,
            eps=eps,
            image_dim=image_dim,
            added_kv_proj_dim=added_kv_proj_dim,
            rope_max_seq_len=rope_max_seq_len,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        inner_dim = num_attention_heads * attention_head_dim
        # Accept a scalar patch size by expanding it to (t, h, w).
        ps = patch_size if isinstance(patch_size, (list, tuple)) else [patch_size] * 3

        # The adapter is purely spatial, so only the (h, w) patch sizes are used
        # for its convolution kernel and stride.
        self.control_adapter = CameraControlAdapter(
            in_channels=control_adapter_in_channels,
            out_channels=inner_dim,
            kernel_size=tuple(ps[1:]),
            stride=tuple(ps[1:]),
            downscale_factor=control_adapter_downscale_factor,
            num_residual_blocks=control_adapter_num_residual_blocks,
        )

    # Without this decorator ``attention_kwargs`` would be accepted but ignored.
    @apply_lora_scale("attention_kwargs")
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: torch.Tensor | None = None,
        control_camera_video: torch.Tensor | None = None,
        return_dict: bool = True,
        attention_kwargs: dict[str, Any] | None = None,
    ) -> Transformer2DModelOutput | tuple[torch.Tensor]:
        """
        Run the transformer, optionally conditioning on camera embeddings.

        Args:
            hidden_states: 5D latents, unpacked as [B, C, F, H, W].
            timestep: 1D or 2D timestep tensor; a 2D tensor is flattened and its
                second dimension is forwarded as ``timestep_seq_len``.
            encoder_hidden_states: Text conditioning passed to the condition embedder.
            encoder_hidden_states_image: Optional image conditioning; when present
                it is concatenated in front of the text tokens.
            control_camera_video: Optional [B, 24, F, H_px, W_px] Plucker ray
                embeddings. When given, the adapter output is added to the
                patch-embedded latents.
            return_dict: Return a ``Transformer2DModelOutput`` when True, else a
                1-tuple containing the output tensor.
            attention_kwargs: Consumed by the ``apply_lora_scale`` decorator; not
                read inside this method body.

        Returns:
            The un-patchified output wrapped in ``Transformer2DModelOutput`` or a tuple.
        """
        batch_size, _, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        # Rotary embeddings are derived from the raw latents, before patchifying.
        rotary_emb = self.rope(hidden_states)

        hidden_states = self.patch_embedding(hidden_states)

        # Camera adapter: additive injection after patch embedding
        if control_camera_video is not None:
            hidden_states = hidden_states + self.control_adapter(control_camera_video)

        # [B, D, F', H', W'] -> [B, F'*H'*W', D] token sequence.
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        if timestep.ndim == 2:
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        # Split the projected timestep embedding into 6 modulation chunks; the
        # axis depends on whether a per-sequence timestep dimension is present.
        if ts_seq_len is not None:
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
        else:
            for block in self.blocks:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

        # Output modulation: shift/scale come from the learned table plus temb.
        if temb.ndim == 3:
            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
            shift = shift.squeeze(2)
            scale = scale.squeeze(2)
        else:
            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

        # Move to the activations' device (they may differ from temb's device).
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        # Normalise in float32, then cast back to the activation dtype.
        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        # Un-patchify: tokens -> [B, C_out, F, H, W].
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)