# File size: 19,664 Bytes
# ca5da2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
import torch
import torch.nn.functional as F
import numpy as np
from typing import Optional, List, Union, Callable, Tuple
from PIL import Image, ImageOps
from einops import rearrange
from tqdm import tqdm
from diffusers import DiffusionPipeline


def sinusoidal_embedding_1d(dim, position):
    """
    Compute a 1D sinusoidal embedding for a vector of positions (timesteps).

    The angles are evaluated in float64 and the result is cast back to the
    dtype of ``position``.

    Args:
        dim: Embedding width; ``dim // 2`` frequencies are used.
        position: 1D tensor of positions.

    Returns:
        Tensor of shape [len(position), 2 * (dim // 2)] laid out as
        ``[cos(...), sin(...)]`` along the last axis.
    """
    half = dim // 2
    # Exponents 0, 1/half, 2/half, ... give frequencies 10000 ** (-k / half).
    exponents = torch.arange(half, dtype=torch.float64, device=position.device).div(half)
    inv_freq = torch.pow(10000, -exponents)
    angles = torch.outer(position.type(torch.float64), inv_freq)
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
    return embedding.to(position.dtype)


def _build_rope_3d(rope_module, f, h, w, device):
    """
    Assemble 3D rotary-embedding tables for an (f, h, w) token grid.

    The module's ``freqs_cos`` / ``freqs_sin`` buffers are split along dim 1
    into temporal, height and width slices (``t_dim``, ``h_dim``, ``w_dim``).
    Each slice is truncated to the grid extent of its axis, broadcast over
    the other two axes, and the three are concatenated on the last dim.

    Args:
        rope_module: Object exposing ``t_dim``, ``h_dim``, ``w_dim``,
            ``freqs_cos`` and ``freqs_sin``.
        f: Number of frames in the grid.
        h: Grid height.
        w: Grid width.
        device: Device the returned tensors are moved to.

    Returns:
        (freqs_cos, freqs_sin), each of shape
        [1, f*h*w, 1, t_dim + h_dim + w_dim].
    """
    axis_dims = [rope_module.t_dim, rope_module.h_dim, rope_module.w_dim]
    # (rows to keep, broadcastable view shape) for the t / h / w axes.
    layouts = (
        (f, (f, 1, 1, -1)),
        (h, (1, h, 1, -1)),
        (w, (1, 1, w, -1)),
    )

    def _assemble(table):
        pieces = [
            part[:length].view(*view_shape).expand(f, h, w, -1)
            for part, (length, view_shape) in zip(table.split(axis_dims, dim=1), layouts)
        ]
        return torch.cat(pieces, dim=-1).reshape(1, f * h * w, 1, -1).to(device)

    return _assemble(rope_module.freqs_cos), _assemble(rope_module.freqs_sin)


class KiwiEditPipeline(DiffusionPipeline):
    """
    Pipeline for reference-guided video and image editing using KiwiEdit.

    This pipeline uses a Qwen2.5-VL multimodal LLM encoder for understanding
    editing instructions with source visual context, a WanTransformer3DModel
    for diffusion, and AutoencoderKLWan for VAE encoding/decoding.

    The registered ``scheduler`` is not consulted during sampling; the
    flow-match schedule is computed by ``_setup_scheduler`` and stepped by
    ``_scheduler_step``.

    Args:
        transformer: WanTransformer3DModel - DiT backbone for denoising.
        vae: AutoencoderKLWan - 3D causal VAE.
        scheduler: FlowMatchEulerDiscreteScheduler or compatible scheduler.
        mllm_encoder: MLLMEncoder - Qwen2.5-VL MLLM with learnable queries.
        processor: AutoProcessor - Qwen2.5-VL processor/tokenizer bundle.
        source_embedder: ConditionalEmbedder - VAE source conditioning.
        ref_embedder: ConditionalEmbedder - VAE reference conditioning.
    """

    model_cpu_offload_seq = "mllm_encoder->source_embedder->ref_embedder->transformer->vae"

    def __init__(
        self,
        transformer,
        vae,
        scheduler,
        mllm_encoder,
        source_embedder,
        ref_embedder,
        processor=None,
    ):
        """Register all sub-modules and hand the processor to the MLLM encoder."""
        super().__init__()
        if isinstance(processor, (list, tuple)):
            # Diffusers may pass the raw model_index spec; let MLLMEncoder resolve it later.
            processor = None
        self.register_modules(
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
            mllm_encoder=mllm_encoder,
            processor=processor,
            source_embedder=source_embedder,
            ref_embedder=ref_embedder,
        )
        if processor is not None:
            self.mllm_encoder.processor = processor

    # ------------------------------------------------------------------ #
    #                        Helper utilities                             #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _check_resize(height, width, num_frames, h_div=16, w_div=16, t_div=4, t_rem=1):
        """
        Round height/width/num_frames up to valid values.

        ``height`` and ``width`` are rounded up to multiples of ``h_div`` and
        ``w_div``. ``num_frames`` is left alone when
        ``num_frames % t_div == t_rem``; otherwise it becomes the next
        multiple of ``t_div`` (rounding up) plus ``t_rem``.

        Returns:
            Tuple ``(height, width, num_frames)`` after adjustment.
        """
        if height % h_div != 0:
            height = (height + h_div - 1) // h_div * h_div
        if width % w_div != 0:
            width = (width + w_div - 1) // w_div * w_div
        if num_frames % t_div != t_rem:
            num_frames = (num_frames + t_div - 1) // t_div * t_div + t_rem
        return height, width, num_frames

    @staticmethod
    def _preprocess_image(image: Image.Image, dtype, device):
        """
        Convert a PIL Image to a [3, H, W] tensor scaled to [-1, 1].

        Images that are not already in RGB mode (grayscale, RGBA, palette, ...)
        are converted to RGB first so the channel axis always exists and has
        exactly three entries.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        arr = np.array(image, dtype=np.float32)
        tensor = torch.from_numpy(arr).to(dtype=dtype, device=device)
        tensor = tensor / 127.5 - 1.0  # [0, 255] -> [-1, 1]
        tensor = tensor.permute(2, 0, 1)  # H W C -> C H W
        return tensor

    def _preprocess_video(self, frames: List[Image.Image], dtype, device):
        """Convert list of PIL Images to tensor [1, C, T, H, W] in [-1, 1]."""
        tensors = [self._preprocess_image(f, dtype, device) for f in frames]
        video = torch.stack(tensors, dim=1)  # C T H W
        return video.unsqueeze(0)  # 1 C T H W

    def _encode_frames(self, frames: List[Image.Image], device, dtype):
        """
        Encode PIL frames with the VAE and return sampled latents in ``dtype``.

        Pixels are fed to the VAE in float32. The posterior is sampled without
        a generator, so this step is not controlled by the pipeline ``seed``.
        """
        pixels = self._preprocess_video(frames, dtype=torch.float32, device=device)
        return self.vae.encode(pixels).latent_dist.sample().to(dtype=dtype)

    @staticmethod
    def _vae_output_to_video(vae_output):
        """
        Convert VAE output tensor to list of PIL Images.

        Values are mapped from [-1, 1] to [0, 255], clamped, and cast to uint8.

        Args:
            vae_output: Tensor of shape [B, C, T, H, W] (only a leading batch
                dimension of size 1 is squeezed away) or [T, H, W, C].
        """
        if vae_output.dim() == 5:
            vae_output = vae_output.squeeze(0).permute(1, 2, 3, 0)  # T H W C
        frames = []
        for t in range(vae_output.shape[0]):
            frame = ((vae_output[t] + 1.0) * 127.5).clamp(0, 255)
            frame = frame.to(device="cpu", dtype=torch.uint8).numpy()
            frames.append(Image.fromarray(frame))
        return frames

    # ------------------------------------------------------------------ #
    #                   Custom Flow Match Scheduler                       #
    # ------------------------------------------------------------------ #

    def _setup_scheduler(self, num_inference_steps, denoising_strength=1.0, shift=5.0):
        """
        Set up flow-match sigmas and timesteps matching the original diffsynth
        FlowMatchScheduler with extra_one_step=True and shift.

        Returns:
            Tuple ``(sigmas, timesteps)``, each a 1D tensor of length
            ``num_inference_steps``; ``timesteps = sigmas * 1000``.
        """
        sigma_min = 0.003 / 1.002
        sigma_max = 1.0
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        # extra_one_step: generate N+1 points, drop last
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        # Apply shift
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * 1000  # num_train_timesteps = 1000
        return sigmas, timesteps

    def _scheduler_step(self, model_output, sigmas, step_index, sample):
        """Euler step for flow matching; the sigma after the last step is 0."""
        sigma = sigmas[step_index]
        if step_index + 1 >= len(sigmas):
            sigma_next = 0.0
        else:
            sigma_next = sigmas[step_index + 1]
        return sample + model_output * (sigma_next - sigma)

    def _scheduler_add_noise(self, original_samples, noise, sigmas, step_index):
        """Add noise at given timestep for img2img / video2video."""
        sigma = sigmas[step_index]
        return (1 - sigma) * original_samples + sigma * noise

    def _scheduler_get_sigma(self, timestep, sigmas, timesteps):
        """Get the sigma whose scheduled timestep is closest to ``timestep``."""
        timestep_id = torch.argmin((timesteps - timestep).abs())
        return sigmas[timestep_id]

    # ------------------------------------------------------------------ #
    #                    Transformer forward helpers                      #
    # ------------------------------------------------------------------ #

    def _model_forward(
        self,
        latents,
        timestep,
        context,
        vae_source_input=None,
        vae_ref_image=None,
        sigmas=None,
        timesteps_schedule=None,
    ):
        """
        Custom DiT forward pass that handles source/ref conditioning.
        Mirrors model_fn_wan_video from the original diffsynth pipeline.

        Args:
            latents: Noisy latents fed to the transformer's patch embedding.
            timestep: Tensor holding the current timestep.
            context: Conditioning passed to every transformer block.
            vae_source_input: Optional source latents; embedded by
                ``source_embedder`` and added to the patch tokens, scaled by
                the current sigma.
            vae_ref_image: Optional list of reference latents; embedded by
                ``ref_embedder`` and attached as extra tokens. ``None`` or an
                empty list disables reference conditioning.
            sigmas: Sigma schedule (required when ``vae_source_input`` is set).
            timesteps_schedule: Timestep schedule matching ``sigmas``.

        Returns:
            Model prediction rearranged back to [B, C, F, H, W] latent layout.
        """
        device = latents.device
        dtype = latents.dtype
        t = self.transformer

        # --- Timestep embedding ---
        timestep_emb = sinusoidal_embedding_1d(
            t.config.freq_dim, timestep
        ).to(dtype)
        time_emb = t.condition_embedder.time_embedder(timestep_emb)
        # diffusers time_proj = Linear only (SiLU is applied separately)
        t_mod = t.condition_embedder.time_proj(F.silu(time_emb)).unflatten(
            1, (6, t.config.num_attention_heads * t.config.attention_head_dim)
        )

        # --- Text/context embedding ---
        # NOTE: Do NOT apply text_embedder here. The MLLM encoder's connector
        # already projects to dit_dim. text_embedder is for raw text encoder
        # output (text_dim → dim), which doesn't apply to MLLM output.

        # --- Patchify latents (+ optional source conditioning) ---
        source_cond = None
        if vae_source_input is not None:
            source_cond = self.source_embedder(vae_source_input)
        x = t.patch_embedding(latents)
        if source_cond is not None:
            # Source conditioning is weighted by the sigma of this timestep.
            sigma = self._scheduler_get_sigma(timestep, sigmas, timesteps_schedule)
            x = x + source_cond * sigma

        f, h, w = x.shape[2:]
        x = rearrange(x, "b c f h w -> b (f h w) c").contiguous()

        # --- Reference image conditioning ---
        has_ref = vae_ref_image is not None and len(vae_ref_image) > 0
        vae_ref_input_length = 0
        rope_frames = f
        if has_ref:
            if len(vae_ref_image) > 1:
                vae_ref = torch.cat(vae_ref_image, dim=2)  # concat along temporal
            else:
                vae_ref = vae_ref_image[0]

            vae_ref = self.ref_embedder(vae_ref)
            ref_f = vae_ref.shape[2]
            vae_ref = rearrange(vae_ref, "b c f h w -> b (f h w) c").contiguous()

            # RoPE must cover the extended sequence (main + ref tokens).
            rope_frames = f + ref_f
            vae_ref_input_length = vae_ref.shape[1]

            if self.ref_embedder.config.ref_pad_first:
                x = torch.cat([vae_ref, x], dim=1)
            else:
                x = torch.cat([x, vae_ref], dim=1)

        # --- 3D RoPE frequencies (real-valued cos/sin format) ---
        rotary_emb = _build_rope_3d(t.rope, rope_frames, h, w, device)

        # --- Transformer blocks ---
        for block in t.blocks:
            x = block(x, context, t_mod, rotary_emb)

        # --- Output head ---
        # Match diffusers' FP32 norm + modulation + projection
        table = t.scale_shift_table
        shift, scale = (
            table.to(device=device) + time_emb.unsqueeze(1)
        ).chunk(2, dim=1)
        shift = shift.to(device=x.device)
        scale = scale.to(device=x.device)
        x = (t.norm_out(x.float()) * (1 + scale) + shift).type_as(x)
        x = t.proj_out(x)

        # --- Remove ref tokens from output ---
        if vae_ref_input_length > 0:
            if self.ref_embedder.config.ref_pad_first:
                x = x[:, vae_ref_input_length:, :]
            else:
                x = x[:, :-vae_ref_input_length, :]

        # --- Unpatchify ---
        patch_size = t.config.patch_size
        x = rearrange(
            x,
            "b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
            f=f, h=h, w=w,
            x=patch_size[0], y=patch_size[1], z=patch_size[2],
        )
        return x

    # ------------------------------------------------------------------ #
    #                          Main __call__                              #
    # ------------------------------------------------------------------ #

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        source_video: Optional[List[Image.Image]] = None,
        source_input: Optional[List[Image.Image]] = None,
        ref_image: Optional[List[Image.Image]] = None,
        negative_prompt: Optional[str] = "",
        input_video: Optional[List[Image.Image]] = None,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 1.0,
        sigma_shift: float = 5.0,
        denoising_strength: float = 1.0,
        seed: Optional[int] = None,
        tiled: bool = True,
        tile_size: Tuple[int, int] = (30, 52),
        tile_stride: Tuple[int, int] = (15, 26),
        output_type: str = "pil",
        progress_bar: Callable = tqdm,
    ) -> Union[List[Image.Image], torch.Tensor]:
        """
        Run KiwiEdit inference.

        Args:
            prompt: Editing instruction text.
            source_video: Source video/image frames for MLLM context (also used as
                source_input if source_input is not provided). The MLLM encoder
                is only invoked when this is given.
            source_input: Source frames for VAE conditioning. If None but source_video
                is provided, source_video is used.
            ref_image: Optional reference image(s) for guided editing. An empty
                list is treated like None.
            negative_prompt: Currently unused (classifier-free guidance is not
                implemented).
            input_video: Optional input video for video-to-video (adds noise then
                denoises). Frames are truncated to ``num_frames`` and resized to
                the target resolution, like the source frames.
            height: Output height in pixels (rounded up to a multiple of 32).
            width: Output width in pixels (rounded up to a multiple of 32).
            num_frames: Number of output frames (1 for image editing).
            num_inference_steps: Number of denoising steps.
            guidance_scale: Currently unused; a warning is emitted when it is
                not 1.0 because no guidance is applied.
            sigma_shift: Flow matching shift parameter.
            denoising_strength: Scales the starting sigma of the schedule
                (1.0 = start from the maximum sigma).
            seed: Random seed for the initial noise. VAE posterior sampling is
                not seeded.
            tiled: Currently unused.
            tile_size: Currently unused.
            tile_stride: Currently unused.
            output_type: "pil" for PIL Images, "latent" for raw latents.
            progress_bar: Progress bar callable (e.g., tqdm).

        Returns:
            List of PIL Images (video frames), or the latent tensor when
            ``output_type == "latent"``.
        """
        device = self._execution_device
        dtype = torch.bfloat16

        if guidance_scale != 1.0:
            # Be explicit instead of silently producing unguided samples.
            import warnings

            warnings.warn(
                "KiwiEditPipeline does not implement classifier-free guidance; "
                "`guidance_scale` and `negative_prompt` are ignored."
            )

        # An empty reference list carries no conditioning; treat it as absent
        # so no downstream code indexes into an empty list.
        if isinstance(ref_image, (list, tuple)) and not ref_image:
            ref_image = None

        # --- 1. Shape check ---
        # VAE spatial factor is 16, transformer patch spatial is 2,
        # so pixel dims must be multiples of 32.
        height, width, num_frames = self._check_resize(
            height, width, num_frames, h_div=32, w_div=32
        )
        target_size = (width, height)

        # --- 2. Determine VAE parameters ---
        z_dim = self.vae.config.z_dim
        # Compute upsampling factor from VAE config
        dim_mult = self.vae.config.get("dim_mult", [1, 2, 4, 4])
        # "temperal_downsample" (sic) is the key read from the VAE config.
        temporal_downsample = self.vae.config.get("temperal_downsample", [False, True, True])
        # Wan VideoVAE spatial factor is 2^(len(dim_mult)) due to extra
        # downsampling in the encoder beyond the level transitions.
        spatial_factor = 2 ** len(dim_mult)  # 16 for 4 levels
        temporal_factor = 2 ** sum(temporal_downsample)  # 4 for [F, T, T]

        # --- 3. MLLM encoding ---
        context = None
        if source_video is not None:
            self.mllm_encoder._ensure_qwen_loaded()
            if ref_image is not None:
                # Ref mode always uses the video path (even for a single frame)
                context = self.mllm_encoder(
                    prompt, src_video=source_video, ref_image=ref_image
                )
            elif len(source_video) == 1:
                context = self.mllm_encoder(prompt, src_image=source_video)
            else:
                context = self.mllm_encoder(prompt, src_video=source_video)

        # --- 4. Setup scheduler ---
        sigmas, timesteps = self._setup_scheduler(
            num_inference_steps, denoising_strength, sigma_shift
        )
        sigmas = sigmas.to(device)
        timesteps = timesteps.to(device)

        # --- 5. Initialize noise ---
        latent_length = (num_frames - 1) // temporal_factor + 1
        latent_h = height // spatial_factor
        latent_w = width // spatial_factor
        shape = (1, z_dim, latent_length, latent_h, latent_w)

        # Noise is drawn on CPU in float32, then moved to the execution device.
        generator = None if seed is None else torch.Generator("cpu").manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device="cpu", dtype=torch.float32)
        noise = noise.to(dtype=dtype, device=device)

        # NOTE(review): latents go to/from the VAE without any mean/std
        # normalization in this pipeline — confirm that matches the checkpoints.

        # --- 6. Encode source input ---
        vae_source_input = None
        # Fall back to source_video if source_input not provided
        src_for_vae = source_input if source_input is not None else source_video
        if src_for_vae is not None:
            # Resize source frames to match the (possibly adjusted) target dimensions
            src_frames = [
                frame.resize(target_size, Image.LANCZOS)
                for frame in src_for_vae[:num_frames]
            ]
            vae_source_input = self._encode_frames(src_frames, device, dtype)

        # --- 7. Encode reference images ---
        vae_ref_image = None
        if ref_image is not None:
            vae_ref_image = []
            for item in ref_image:
                # Letterbox onto a white canvas instead of stretching.
                item = ImageOps.pad(item, target_size, color="white", centering=(0.5, 0.5))
                vae_ref_image.append(self._encode_frames([item], device, dtype))

        # --- 8. Handle input_video (video-to-video) ---
        if input_video is not None:
            # Same truncation/resizing as the source frames so the encoded
            # latents line up with the noise tensor.
            video_frames = [
                frame.resize(target_size, Image.LANCZOS)
                for frame in input_video[:num_frames]
            ]
            input_latents = self._encode_frames(video_frames, device, dtype)
            latents = self._scheduler_add_noise(input_latents, noise, sigmas, 0)
        else:
            latents = noise

        # --- 9. Denoising loop ---
        # Only the conditional prediction is computed (no CFG branch).
        for step_idx, timestep_val in enumerate(progress_bar(timesteps)):
            timestep = timestep_val.unsqueeze(0).to(dtype=dtype, device=device)

            noise_pred = self._model_forward(
                latents=latents,
                timestep=timestep,
                context=context,
                vae_source_input=vae_source_input,
                vae_ref_image=vae_ref_image,
                sigmas=sigmas,
                timesteps_schedule=timesteps,
            )

            # Scheduler step
            latents = self._scheduler_step(noise_pred, sigmas, step_idx, latents)

        # --- 10. Decode ---
        if output_type == "latent":
            return latents

        # Latents are bfloat16 here; feed the VAE in its own dtype so decoding
        # does not depend on the VAE happening to run in bfloat16.
        vae_dtype = getattr(self.vae, "dtype", latents.dtype)
        video = self.vae.decode(latents.to(dtype=vae_dtype)).sample
        return self._vae_output_to_video(video)