linyq committed · verified
Commit 9eccfc5 · 1 Parent(s): 82bc279

Update pipeline_kiwi_edit.py

Files changed (1):
  1. pipeline_kiwi_edit.py +11 -44
pipeline_kiwi_edit.py CHANGED
@@ -63,10 +63,9 @@ class KiwiEditPipeline(DiffusionPipeline):
         mllm_encoder: MLLMEncoder - Qwen2.5-VL MLLM with learnable queries.
         processor: AutoProcessor - Qwen2.5-VL processor/tokenizer bundle.
         source_embedder: ConditionalEmbedder - VAE source conditioning.
-        ref_embedder: ConditionalEmbedder - VAE reference conditioning.
         """
 
-    model_cpu_offload_seq = "mllm_encoder->source_embedder->ref_embedder->transformer->vae"
+    model_cpu_offload_seq = "mllm_encoder->source_embedder->transformer->vae"
 
     def __init__(
         self,
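With ref_embedder gone, the offload sequence now lists only the remaining components. A minimal sketch of how this string is consumed, assuming a local import of this file and a hypothetical checkpoint id:

    import torch
    from pipeline_kiwi_edit import KiwiEditPipeline  # this repo's file, imported locally

    # Hypothetical checkpoint id; substitute the actual repo path.
    pipe = KiwiEditPipeline.from_pretrained("linyq/kiwi-edit", torch_dtype=torch.bfloat16)

    # enable_model_cpu_offload() (a standard DiffusionPipeline method) walks
    # model_cpu_offload_seq in order, moving each component to the GPU only
    # for its step: mllm_encoder -> source_embedder -> transformer -> vae.
    pipe.enable_model_cpu_offload()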
@@ -75,7 +74,6 @@ class KiwiEditPipeline(DiffusionPipeline):
         scheduler,
         mllm_encoder,
         source_embedder,
-        ref_embedder,
         processor=None,
     ):
         super().__init__()
@@ -89,7 +87,6 @@ class KiwiEditPipeline(DiffusionPipeline):
             mllm_encoder=mllm_encoder,
             processor=processor,
             source_embedder=source_embedder,
-            ref_embedder=ref_embedder,
         )
         if processor is not None:
             self.mllm_encoder.processor = processor
@@ -229,29 +226,6 @@ class KiwiEditPipeline(DiffusionPipeline):
         # --- 3D RoPE frequencies (real-valued cos/sin format) ---
         rotary_emb = _build_rope_3d(t.rope, f, h, w, device)
 
-        # --- Reference image conditioning ---
-        vae_ref_input_length = 0
-        if vae_ref_image is not None:
-            if len(vae_ref_image) > 1:
-                vae_ref = torch.cat(vae_ref_image, dim=2)  # concat along temporal
-            else:
-                vae_ref = vae_ref_image[0]
-
-            vae_ref = self.ref_embedder(vae_ref)
-            ref_f, ref_h, ref_w = vae_ref.shape[2:]
-            vae_ref = rearrange(vae_ref, "b c f h w -> b (f h w) c").contiguous()
-
-            # Recompute RoPE for extended sequence (main + ref tokens)
-            total_f = f + ref_f
-            rotary_emb = _build_rope_3d(t.rope, total_f, h, w, device)
-
-            vae_ref_input_length = vae_ref.shape[1]
-
-            if self.ref_embedder.config.ref_pad_first:
-                x = torch.cat([vae_ref, x], dim=1)
-            else:
-                x = torch.cat([x, vae_ref], dim=1)
-
         # --- Transformer blocks ---
         for block in t.blocks:
             x = block(x, context, t_mod, rotary_emb)
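The deleted block follows a common in-context conditioning pattern: flatten the reference latents into tokens, extend the RoPE grid to cover them, concatenate them with the main sequence, then slice them back off after the transformer blocks (the matching slice-off is removed in the next hunk). A standalone sketch of that pattern, with hypothetical names:

    import torch

    def run_with_ref_tokens(x, ref_tokens, blocks, pad_first=True):
        # x: (b, n, c) main token sequence; ref_tokens: (b, m, c) conditioning tokens.
        m = ref_tokens.shape[1]
        x = torch.cat([ref_tokens, x] if pad_first else [x, ref_tokens], dim=1)
        for block in blocks:  # each block attends over main + ref tokens jointly
            x = block(x)
        # Drop the conditioning tokens so the output length matches the input.
        return x[:, m:, :] if pad_first else x[:, :-m, :]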
@@ -267,13 +241,6 @@ class KiwiEditPipeline(DiffusionPipeline):
         x = (t.norm_out(x.float()) * (1 + scale) + shift).type_as(x)
         x = t.proj_out(x)
 
-        # --- Remove ref tokens from output ---
-        if vae_ref_image is not None and vae_ref_input_length > 0:
-            if self.ref_embedder.config.ref_pad_first:
-                x = x[:, vae_ref_input_length:, :]
-            else:
-                x = x[:, :-vae_ref_input_length, :]
-
         # --- Unpatchify ---
         patch_size = t.config.patch_size
         x = rearrange(
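The unpatchify step that follows (unchanged by this commit) restores a (b, c, f, h, w) video from flattened patch tokens. A minimal sketch with einops, assuming a cubic patch and channel-first packing; the actual axis order inside the patch dimension in pipeline_kiwi_edit.py may differ:

    import torch
    from einops import rearrange

    b, f, h, w, c, p = 1, 4, 8, 8, 16, 2
    tokens = torch.randn(b, f * h * w, c * p**3)  # (batch, tokens, patch_dim)
    video = rearrange(
        tokens,
        "b (f h w) (c pf ph pw) -> b c (f pf) (h ph) (w pw)",
        f=f, h=h, w=w, pf=p, ph=p, pw=p,
    )
    assert video.shape == (b, c, f * p, h * p, w * p)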
@@ -410,15 +377,15 @@ class KiwiEditPipeline(DiffusionPipeline):
         vae_source_input = vae_source_input.to(dtype=dtype)
 
         # --- 7. Encode reference images ---
-        vae_ref_image = None
-        if ref_image is not None:
-            vae_ref_image = []
-            for item in ref_image:
-                target_size = (width, height)
-                item = ImageOps.pad(item, target_size, color="white", centering=(0.5, 0.5))
-                ref_tensor = self._preprocess_video([item], dtype=torch.float32, device=device)
-                ref_latent = self.vae.encode(ref_tensor).latent_dist.sample()
-                vae_ref_image.append(ref_latent.to(dtype=dtype))
+        # vae_ref_image = None
+        # if ref_image is not None:
+        #     vae_ref_image = []
+        #     for item in ref_image:
+        #         target_size = (width, height)
+        #         item = ImageOps.pad(item, target_size, color="white", centering=(0.5, 0.5))
+        #         ref_tensor = self._preprocess_video([item], dtype=torch.float32, device=device)
+        #         ref_latent = self.vae.encode(ref_tensor).latent_dist.sample()
+        #         vae_ref_image.append(ref_latent.to(dtype=dtype))
 
         # --- 8. Handle input_video (video-to-video) ---
         if input_video is not None:
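For reference, ImageOps.pad (used in the now-disabled loop) letterboxes rather than stretches: it resizes the image to fit inside the target while preserving aspect ratio, then fills the remainder with the given color. A quick check:

    from PIL import Image, ImageOps

    img = Image.new("RGB", (640, 360), "gray")  # placeholder input image
    padded = ImageOps.pad(img, (512, 512), color="white", centering=(0.5, 0.5))
    # The 640x360 content is scaled to 512x288 and centered; white bars
    # fill the top and bottom.
    assert padded.size == (512, 512)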
@@ -439,7 +406,7 @@ class KiwiEditPipeline(DiffusionPipeline):
             timestep=timestep,
             context=context,
             vae_source_input=vae_source_input,
-            vae_ref_image=vae_ref_image,
+            # vae_ref_image=vae_ref_image,
             sigmas=sigmas,
             timesteps_schedule=timesteps,
         )
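Taken together, the commit removes reference conditioning from the offload sequence, the constructor, and the denoising loop, and comments out (rather than deletes) the encoding step so it can be restored later. Callers should stop passing ref_image; a hedged call sketch (parameter names besides ref_image and input_video are assumptions):

    # Before this commit: pipe(..., ref_image=[ref_img], ...)
    # After: the reference path is dead code, so drop ref_image.
    result = pipe(
        prompt="make the jacket red",  # assumed parameter name
        input_video=frames,            # per the step-8 video-to-video path
        width=512, height=512,         # assumed; match the sizes used elsewhere
    )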
 