| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from typing import Optional, List, Union, Callable, Tuple |
| | from PIL import Image, ImageOps |
| | from einops import rearrange |
| | from tqdm import tqdm |
| | from diffusers import DiffusionPipeline |
| |
|
| |
|
def sinusoidal_embedding_1d(dim, position):
    """Return a 1D sinusoidal positional embedding of shape [len(position), dim].

    The first dim//2 channels are cosines and the last dim//2 are sines, with
    per-channel frequencies 10000**(-k / (dim // 2)) for k in [0, dim//2).
    All math is done in float64, then cast back to position's dtype.
    """
    half = dim // 2
    exponents = torch.arange(half, dtype=torch.float64, device=position.device) / half
    freqs = 10000.0 ** (-exponents)
    angles = torch.outer(position.type(torch.float64), freqs)
    emb = torch.cat([angles.cos(), angles.sin()], dim=1)
    return emb.to(position.dtype)
| |
|
| |
|
| | def _build_rope_3d(rope_module, f, h, w, device): |
| | """ |
| | Build 3D RoPE (cos, sin) for a given (f, h, w) grid using the |
| | WanRotaryPosEmbed module's precomputed buffers. |
| | |
| | Returns: |
| | (freqs_cos, freqs_sin) each of shape [1, f*h*w, 1, head_dim] |
| | """ |
| | split_sizes = [rope_module.t_dim, rope_module.h_dim, rope_module.w_dim] |
| | cos_parts = rope_module.freqs_cos.split(split_sizes, dim=1) |
| | sin_parts = rope_module.freqs_sin.split(split_sizes, dim=1) |
| |
|
| | cos_f = cos_parts[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1) |
| | cos_h = cos_parts[1][:h].view(1, h, 1, -1).expand(f, h, w, -1) |
| | cos_w = cos_parts[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) |
| |
|
| | sin_f = sin_parts[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1) |
| | sin_h = sin_parts[1][:h].view(1, h, 1, -1).expand(f, h, w, -1) |
| | sin_w = sin_parts[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) |
| |
|
| | freqs_cos = torch.cat([cos_f, cos_h, cos_w], dim=-1).reshape(1, f * h * w, 1, -1).to(device) |
| | freqs_sin = torch.cat([sin_f, sin_h, sin_w], dim=-1).reshape(1, f * h * w, 1, -1).to(device) |
| | return freqs_cos, freqs_sin |
| |
|
| |
|
class KiwiEditPipeline(DiffusionPipeline):
    """
    Pipeline for reference-guided video and image editing using KiwiEdit.

    This pipeline uses a Qwen2.5-VL multimodal LLM encoder for understanding
    editing instructions with source visual context, a WanTransformer3DModel
    for diffusion, and AutoencoderKLWan for VAE encoding/decoding.

    Args:
        transformer: WanTransformer3DModel - DiT backbone for denoising.
        vae: AutoencoderKLWan - 3D causal VAE.
        scheduler: FlowMatchEulerDiscreteScheduler or compatible scheduler.
        mllm_encoder: MLLMEncoder - Qwen2.5-VL MLLM with learnable queries.
        processor: AutoProcessor - Qwen2.5-VL processor/tokenizer bundle.
        source_embedder: ConditionalEmbedder - VAE source conditioning.
        ref_embedder: ConditionalEmbedder - VAE reference conditioning.
    """

    # CPU-offload order used by diffusers' enable_model_cpu_offload():
    # components are moved to GPU in this sequence and offloaded afterwards.
    model_cpu_offload_seq = "mllm_encoder->source_embedder->ref_embedder->transformer->vae"

    def __init__(
        self,
        transformer,
        vae,
        scheduler,
        mllm_encoder,
        source_embedder,
        ref_embedder,
        processor=None,
    ):
        super().__init__()
        # A list/tuple here is treated as an invalid processor and dropped —
        # presumably guards against a mis-deserialized processor config from
        # from_pretrained(). TODO(review): confirm against the saving code.
        if isinstance(processor, (list, tuple)):
            processor = None
        self.register_modules(
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
            mllm_encoder=mllm_encoder,
            processor=processor,
            source_embedder=source_embedder,
            ref_embedder=ref_embedder,
        )
        if processor is not None:
            # Make the processor reachable from the MLLM encoder, which is
            # what actually consumes it at encode time.
            self.mllm_encoder.processor = processor
| |
|
| | |
| | |
| | |
| |
|
| | @staticmethod |
| | def _check_resize(height, width, num_frames, h_div=16, w_div=16, t_div=4, t_rem=1): |
| | """Round height/width/num_frames to valid values.""" |
| | if height % h_div != 0: |
| | height = (height + h_div - 1) // h_div * h_div |
| | if width % w_div != 0: |
| | width = (width + w_div - 1) // w_div * w_div |
| | if num_frames % t_div != t_rem: |
| | num_frames = (num_frames + t_div - 1) // t_div * t_div + t_rem |
| | return height, width, num_frames |
| |
|
| | @staticmethod |
| | def _preprocess_image(image: Image.Image, dtype, device): |
| | """Convert PIL Image to tensor in [-1, 1].""" |
| | arr = np.array(image, dtype=np.float32) |
| | tensor = torch.from_numpy(arr).to(dtype=dtype, device=device) |
| | tensor = tensor / 127.5 - 1.0 |
| | tensor = tensor.permute(2, 0, 1) |
| | return tensor |
| |
|
| | def _preprocess_video(self, frames: List[Image.Image], dtype, device): |
| | """Convert list of PIL Images to tensor [1, C, T, H, W] in [-1, 1].""" |
| | tensors = [self._preprocess_image(f, dtype, device) for f in frames] |
| | video = torch.stack(tensors, dim=1) |
| | return video.unsqueeze(0) |
| |
|
| | @staticmethod |
| | def _vae_output_to_video(vae_output): |
| | """Convert VAE output tensor to list of PIL Images.""" |
| | |
| | if vae_output.dim() == 5: |
| | vae_output = vae_output.squeeze(0).permute(1, 2, 3, 0) |
| | frames = [] |
| | for t in range(vae_output.shape[0]): |
| | frame = ((vae_output[t] + 1.0) * 127.5).clamp(0, 255) |
| | frame = frame.to(device="cpu", dtype=torch.uint8).numpy() |
| | frames.append(Image.fromarray(frame)) |
| | return frames |
| |
|
| | |
| | |
| | |
| |
|
| | def _setup_scheduler(self, num_inference_steps, denoising_strength=1.0, shift=5.0): |
| | """ |
| | Set up flow-match sigmas and timesteps matching the original diffsynth |
| | FlowMatchScheduler with extra_one_step=True and shift. |
| | """ |
| | sigma_min = 0.003 / 1.002 |
| | sigma_max = 1.0 |
| | sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength |
| | |
| | sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] |
| | |
| | sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) |
| | timesteps = sigmas * 1000 |
| | return sigmas, timesteps |
| |
|
| | def _scheduler_step(self, model_output, sigmas, step_index, sample): |
| | """Euler step for flow matching.""" |
| | sigma = sigmas[step_index] |
| | if step_index + 1 >= len(sigmas): |
| | sigma_next = 0.0 |
| | else: |
| | sigma_next = sigmas[step_index + 1] |
| | return sample + model_output * (sigma_next - sigma) |
| |
|
| | def _scheduler_add_noise(self, original_samples, noise, sigmas, step_index): |
| | """Add noise at given timestep for img2img / video2video.""" |
| | sigma = sigmas[step_index] |
| | return (1 - sigma) * original_samples + sigma * noise |
| |
|
| | def _scheduler_get_sigma(self, timestep, sigmas, timesteps): |
| | """Get sigma for a given timestep.""" |
| | timestep_id = torch.argmin((timesteps - timestep).abs()) |
| | return sigmas[timestep_id] |
| |
|
| | |
| | |
| | |
| |
|
    def _model_forward(
        self,
        latents,
        timestep,
        context,
        vae_source_input=None,
        vae_ref_image=None,
        sigmas=None,
        timesteps_schedule=None,
    ):
        """
        Custom DiT forward pass that handles source/ref conditioning.
        Mirrors model_fn_wan_video from the original diffsynth pipeline.

        Args:
            latents: Noisy video latents [B, C, F, H, W].
            timestep: Current diffusion timestep (1-element tensor).
            context: MLLM text/visual embeddings passed to every block.
            vae_source_input: Optional VAE latents of the source video; its
                embedding is added to the patchified latents, scaled by sigma.
            vae_ref_image: Optional list of VAE reference latents appended as
                extra tokens along the sequence axis.
            sigmas / timesteps_schedule: Full schedules, needed to look up the
                sigma matching `timestep` for source conditioning.

        Returns:
            Predicted velocity/noise tensor with the same shape as `latents`.
        """
        device = latents.device
        dtype = latents.dtype
        t = self.transformer

        # Timestep embedding -> per-block modulation. t_mod is unflattened to
        # 6 chunks (shift/scale/gate for attention and FFN in each block).
        timestep_emb = sinusoidal_embedding_1d(
            t.config.freq_dim, timestep
        ).to(dtype)
        time_emb = t.condition_embedder.time_embedder(timestep_emb)
        t_mod = t.condition_embedder.time_proj(F.silu(time_emb)).unflatten(
            1, (6, t.config.num_attention_heads * t.config.attention_head_dim)
        )

        # Patchify latents; when source conditioning is present, blend it in
        # with weight sigma so its influence fades as denoising progresses.
        x = latents
        if vae_source_input is not None:
            vae_source_cond = self.source_embedder(vae_source_input)
            x = t.patch_embedding(x)
            sigma = self._scheduler_get_sigma(timestep, sigmas, timesteps_schedule)
            x = x + vae_source_cond * sigma
        else:
            x = t.patch_embedding(x)

        # Flatten the (frame, height, width) token grid into a sequence.
        f, h, w = x.shape[2:]
        x = rearrange(x, "b c f h w -> b (f h w) c").contiguous()

        # 3D rotary embeddings for the base token grid.
        rotary_emb = _build_rope_3d(t.rope, f, h, w, device)

        # Optionally append reference tokens after embedding them.
        vae_ref_input_length = 0
        if vae_ref_image is not None:
            if len(vae_ref_image) > 1:
                # Multiple refs are stacked along the frame axis (dim=2).
                vae_ref = torch.cat(vae_ref_image, dim=2)
            else:
                vae_ref = vae_ref_image[0]

            vae_ref = self.ref_embedder(vae_ref)
            ref_f, ref_h, ref_w = vae_ref.shape[2:]
            vae_ref = rearrange(vae_ref, "b c f h w -> b (f h w) c").contiguous()

            # Rebuild RoPE over the extended frame count so ref tokens get
            # distinct temporal positions.
            # NOTE(review): this assumes ref_h == h and ref_w == w — confirm
            # the ref embedder always produces the same spatial grid.
            total_f = f + ref_f
            rotary_emb = _build_rope_3d(t.rope, total_f, h, w, device)

            vae_ref_input_length = vae_ref.shape[1]

            # ref_pad_first controls whether ref tokens are prepended or
            # appended to the latent sequence.
            if self.ref_embedder.config.ref_pad_first:
                x = torch.cat([vae_ref, x], dim=1)
            else:
                x = torch.cat([x, vae_ref], dim=1)

        # Transformer blocks: each consumes the text context, the timestep
        # modulation, and the (cos, sin) rotary tables.
        for block in t.blocks:
            x = block(x, context, t_mod, rotary_emb)

        # Output head: AdaLN-style shift/scale from the scale_shift_table,
        # applied in float32 for stability, then projected to patch pixels.
        table = t.scale_shift_table
        shift, scale = (
            table.to(device=device) + time_emb.unsqueeze(1)
        ).chunk(2, dim=1)
        shift = shift.to(device=x.device)
        scale = scale.to(device=x.device)
        x = (t.norm_out(x.float()) * (1 + scale) + shift).type_as(x)
        x = t.proj_out(x)

        # Drop the reference tokens from whichever end they were attached to.
        if vae_ref_image is not None and vae_ref_input_length > 0:
            if self.ref_embedder.config.ref_pad_first:
                x = x[:, vae_ref_input_length:, :]
            else:
                x = x[:, :-vae_ref_input_length, :]

        # Un-patchify back to [B, C, F, H, W].
        patch_size = t.config.patch_size
        x = rearrange(
            x,
            "b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
            f=f, h=h, w=w,
            x=patch_size[0], y=patch_size[1], z=patch_size[2],
        )
        return x
| |
|
| | |
| | |
| | |
| |
|
| | @torch.no_grad() |
| | def __call__( |
| | self, |
| | prompt: str, |
| | source_video: Optional[List[Image.Image]] = None, |
| | source_input: Optional[List[Image.Image]] = None, |
| | ref_image: Optional[List[Image.Image]] = None, |
| | negative_prompt: Optional[str] = "", |
| | input_video: Optional[List[Image.Image]] = None, |
| | height: int = 480, |
| | width: int = 832, |
| | num_frames: int = 81, |
| | num_inference_steps: int = 50, |
| | guidance_scale: float = 1.0, |
| | sigma_shift: float = 5.0, |
| | denoising_strength: float = 1.0, |
| | seed: Optional[int] = None, |
| | tiled: bool = True, |
| | tile_size: Tuple[int, int] = (30, 52), |
| | tile_stride: Tuple[int, int] = (15, 26), |
| | output_type: str = "pil", |
| | progress_bar: Callable = tqdm, |
| | ) -> List[Image.Image]: |
| | """ |
| | Run KiwiEdit inference. |
| | |
| | Args: |
| | prompt: Editing instruction text. |
| | source_video: Source video/image frames for MLLM context (also used as |
| | source_input if source_input is not provided). |
| | source_input: Source frames for VAE conditioning. If None but source_video |
| | is provided, source_video is used. |
| | ref_image: Optional reference image(s) for guided editing. |
| | negative_prompt: Negative prompt for CFG. |
| | input_video: Optional input video for video-to-video (adds noise then denoises). |
| | height: Output height in pixels. |
| | width: Output width in pixels. |
| | num_frames: Number of output frames (1 for image editing). |
| | num_inference_steps: Number of denoising steps. |
| | guidance_scale: Classifier-free guidance scale. |
| | sigma_shift: Flow matching shift parameter. |
| | denoising_strength: How much noise to add (1.0 = full noise). |
| | seed: Random seed for reproducibility. |
| | tiled: Whether to use tiled VAE encoding/decoding. |
| | tile_size: VAE tile size. |
| | tile_stride: VAE tile stride. |
| | output_type: "pil" for PIL Images, "latent" for raw latents. |
| | progress_bar: Progress bar callable (e.g., tqdm). |
| | |
| | Returns: |
| | List of PIL Images (video frames). |
| | """ |
| | device = self._execution_device |
| | dtype = torch.bfloat16 |
| | |
| | |
| | |
| | height, width, num_frames = self._check_resize( |
| | height, width, num_frames, h_div=32, w_div=32 |
| | ) |
| | |
| | |
| | z_dim = self.vae.config.z_dim |
| | |
| | dim_mult = self.vae.config.get("dim_mult", [1, 2, 4, 4]) |
| | temporal_downsample = self.vae.config.get("temperal_downsample", [False, True, True]) |
| | |
| | |
| | spatial_factor = 2 ** len(dim_mult) |
| | temporal_factor = 2 ** sum(temporal_downsample) |
| |
|
| | |
| | context = None |
| | src_video_for_mllm = source_video |
| | if src_video_for_mllm is not None: |
| | self.mllm_encoder._ensure_qwen_loaded() |
| | if ref_image is not None: |
| | |
| | context = self.mllm_encoder( |
| | prompt, src_video=src_video_for_mllm, ref_image=ref_image |
| | ) |
| | elif len(src_video_for_mllm) == 1: |
| | context = self.mllm_encoder( |
| | prompt, src_image=src_video_for_mllm |
| | ) |
| | else: |
| | context = self.mllm_encoder( |
| | prompt, src_video=src_video_for_mllm |
| | ) |
| | |
| | context_nega = None |
| |
|
| | |
| | sigmas, timesteps = self._setup_scheduler( |
| | num_inference_steps, denoising_strength, sigma_shift |
| | ) |
| | sigmas = sigmas.to(device) |
| | timesteps = timesteps.to(device) |
| |
|
| | |
| | latent_length = (num_frames - 1) // temporal_factor + 1 |
| | latent_h = height // spatial_factor |
| | latent_w = width // spatial_factor |
| | shape = (1, z_dim, latent_length, latent_h, latent_w) |
| |
|
| | generator = None if seed is None else torch.Generator("cpu").manual_seed(seed) |
| | noise = torch.randn(shape, generator=generator, device="cpu", dtype=torch.float32) |
| | noise = noise.to(dtype=dtype, device=device) |
| |
|
| | |
| | vae_source_input = None |
| | |
| | src_for_vae = source_input if source_input is not None else source_video |
| | if src_for_vae is not None: |
| | src_frames = [src_for_vae[i] for i in range(min(num_frames, len(src_for_vae)))] |
| | |
| | src_frames = [f.resize((width, height), Image.LANCZOS) for f in src_frames] |
| | src_tensor = self._preprocess_video(src_frames, dtype=torch.float32, device=device) |
| | vae_source_input = self.vae.encode(src_tensor).latent_dist.sample() |
| | vae_source_input = vae_source_input.to(dtype=dtype) |
| |
|
| | |
| | vae_ref_image = None |
| | if ref_image is not None: |
| | vae_ref_image = [] |
| | for item in ref_image: |
| | target_size = (width, height) |
| | item = ImageOps.pad(item, target_size, color="white", centering=(0.5, 0.5)) |
| | ref_tensor = self._preprocess_video([item], dtype=torch.float32, device=device) |
| | ref_latent = self.vae.encode(ref_tensor).latent_dist.sample() |
| | vae_ref_image.append(ref_latent.to(dtype=dtype)) |
| |
|
| | |
| | if input_video is not None: |
| | input_tensor = self._preprocess_video(input_video, dtype=torch.float32, device=device) |
| | input_latents = self.vae.encode(input_tensor).latent_dist.sample() |
| | input_latents = input_latents.to(dtype=dtype) |
| | latents = self._scheduler_add_noise(input_latents, noise, sigmas, 0) |
| | else: |
| | latents = noise |
| |
|
| | |
| | for step_idx, timestep_val in enumerate(progress_bar(timesteps)): |
| | timestep = timestep_val.unsqueeze(0).to(dtype=dtype, device=device) |
| |
|
| | |
| | noise_pred = self._model_forward( |
| | latents=latents, |
| | timestep=timestep, |
| | context=context, |
| | vae_source_input=vae_source_input, |
| | vae_ref_image=vae_ref_image, |
| | sigmas=sigmas, |
| | timesteps_schedule=timesteps, |
| | ) |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | latents = self._scheduler_step(noise_pred, sigmas, step_idx, latents) |
| |
|
| | |
| | if output_type == "latent": |
| | return latents |
| |
|
| | video = self.vae.decode(latents).sample |
| | video = self._vae_output_to_video(video) |
| | return video |
| |
|