Update pipeline_kiwi_edit.py
Browse files
- pipeline_kiwi_edit.py +11 -44
pipeline_kiwi_edit.py
CHANGED
|
@@ -63,10 +63,9 @@ class KiwiEditPipeline(DiffusionPipeline):
|
|
| 63 |
mllm_encoder: MLLMEncoder - Qwen2.5-VL MLLM with learnable queries.
|
| 64 |
processor: AutoProcessor - Qwen2.5-VL processor/tokenizer bundle.
|
| 65 |
source_embedder: ConditionalEmbedder - VAE source conditioning.
|
| 66 |
-
ref_embedder: ConditionalEmbedder - VAE reference conditioning.
|
| 67 |
"""
|
| 68 |
|
| 69 |
-
model_cpu_offload_seq = "mllm_encoder->source_embedder->ref_embedder->transformer->vae"
|
| 70 |
|
| 71 |
def __init__(
|
| 72 |
self,
|
|
@@ -75,7 +74,6 @@ class KiwiEditPipeline(DiffusionPipeline):
|
|
| 75 |
scheduler,
|
| 76 |
mllm_encoder,
|
| 77 |
source_embedder,
|
| 78 |
-
ref_embedder,
|
| 79 |
processor=None,
|
| 80 |
):
|
| 81 |
super().__init__()
|
|
@@ -89,7 +87,6 @@ class KiwiEditPipeline(DiffusionPipeline):
|
|
| 89 |
mllm_encoder=mllm_encoder,
|
| 90 |
processor=processor,
|
| 91 |
source_embedder=source_embedder,
|
| 92 |
-
ref_embedder=ref_embedder,
|
| 93 |
)
|
| 94 |
if processor is not None:
|
| 95 |
self.mllm_encoder.processor = processor
|
|
@@ -229,29 +226,6 @@ class KiwiEditPipeline(DiffusionPipeline):
|
|
| 229 |
# --- 3D RoPE frequencies (real-valued cos/sin format) ---
|
| 230 |
rotary_emb = _build_rope_3d(t.rope, f, h, w, device)
|
| 231 |
|
| 232 |
-
# --- Reference image conditioning ---
|
| 233 |
-
vae_ref_input_length = 0
|
| 234 |
-
if vae_ref_image is not None:
|
| 235 |
-
if len(vae_ref_image) > 1:
|
| 236 |
-
vae_ref = torch.cat(vae_ref_image, dim=2) # concat along temporal
|
| 237 |
-
else:
|
| 238 |
-
vae_ref = vae_ref_image[0]
|
| 239 |
-
|
| 240 |
-
vae_ref = self.ref_embedder(vae_ref)
|
| 241 |
-
ref_f, ref_h, ref_w = vae_ref.shape[2:]
|
| 242 |
-
vae_ref = rearrange(vae_ref, "b c f h w -> b (f h w) c").contiguous()
|
| 243 |
-
|
| 244 |
-
# Recompute RoPE for extended sequence (main + ref tokens)
|
| 245 |
-
total_f = f + ref_f
|
| 246 |
-
rotary_emb = _build_rope_3d(t.rope, total_f, h, w, device)
|
| 247 |
-
|
| 248 |
-
vae_ref_input_length = vae_ref.shape[1]
|
| 249 |
-
|
| 250 |
-
if self.ref_embedder.config.ref_pad_first:
|
| 251 |
-
x = torch.cat([vae_ref, x], dim=1)
|
| 252 |
-
else:
|
| 253 |
-
x = torch.cat([x, vae_ref], dim=1)
|
| 254 |
-
|
| 255 |
# --- Transformer blocks ---
|
| 256 |
for block in t.blocks:
|
| 257 |
x = block(x, context, t_mod, rotary_emb)
|
|
@@ -267,13 +241,6 @@ class KiwiEditPipeline(DiffusionPipeline):
|
|
| 267 |
x = (t.norm_out(x.float()) * (1 + scale) + shift).type_as(x)
|
| 268 |
x = t.proj_out(x)
|
| 269 |
|
| 270 |
-
# --- Remove ref tokens from output ---
|
| 271 |
-
if vae_ref_image is not None and vae_ref_input_length > 0:
|
| 272 |
-
if self.ref_embedder.config.ref_pad_first:
|
| 273 |
-
x = x[:, vae_ref_input_length:, :]
|
| 274 |
-
else:
|
| 275 |
-
x = x[:, :-vae_ref_input_length, :]
|
| 276 |
-
|
| 277 |
# --- Unpatchify ---
|
| 278 |
patch_size = t.config.patch_size
|
| 279 |
x = rearrange(
|
|
@@ -410,15 +377,15 @@ class KiwiEditPipeline(DiffusionPipeline):
|
|
| 410 |
vae_source_input = vae_source_input.to(dtype=dtype)
|
| 411 |
|
| 412 |
# --- 7. Encode reference images ---
|
| 413 |
-
vae_ref_image = None
|
| 414 |
-
if ref_image is not None:
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
|
| 423 |
# --- 8. Handle input_video (video-to-video) ---
|
| 424 |
if input_video is not None:
|
|
@@ -439,7 +406,7 @@ class KiwiEditPipeline(DiffusionPipeline):
|
|
| 439 |
timestep=timestep,
|
| 440 |
context=context,
|
| 441 |
vae_source_input=vae_source_input,
|
| 442 |
-
vae_ref_image=vae_ref_image,
|
| 443 |
sigmas=sigmas,
|
| 444 |
timesteps_schedule=timesteps,
|
| 445 |
)
|
|
|
|
| 63 |
mllm_encoder: MLLMEncoder - Qwen2.5-VL MLLM with learnable queries.
|
| 64 |
processor: AutoProcessor - Qwen2.5-VL processor/tokenizer bundle.
|
| 65 |
source_embedder: ConditionalEmbedder - VAE source conditioning.
|
|
|
|
| 66 |
"""
|
| 67 |
|
| 68 |
+
model_cpu_offload_seq = "mllm_encoder->source_embedder->transformer->vae"
|
| 69 |
|
| 70 |
def __init__(
|
| 71 |
self,
|
|
|
|
| 74 |
scheduler,
|
| 75 |
mllm_encoder,
|
| 76 |
source_embedder,
|
|
|
|
| 77 |
processor=None,
|
| 78 |
):
|
| 79 |
super().__init__()
|
|
|
|
| 87 |
mllm_encoder=mllm_encoder,
|
| 88 |
processor=processor,
|
| 89 |
source_embedder=source_embedder,
|
|
|
|
| 90 |
)
|
| 91 |
if processor is not None:
|
| 92 |
self.mllm_encoder.processor = processor
|
|
|
|
| 226 |
# --- 3D RoPE frequencies (real-valued cos/sin format) ---
|
| 227 |
rotary_emb = _build_rope_3d(t.rope, f, h, w, device)
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
# --- Transformer blocks ---
|
| 230 |
for block in t.blocks:
|
| 231 |
x = block(x, context, t_mod, rotary_emb)
|
|
|
|
| 241 |
x = (t.norm_out(x.float()) * (1 + scale) + shift).type_as(x)
|
| 242 |
x = t.proj_out(x)
|
| 243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
# --- Unpatchify ---
|
| 245 |
patch_size = t.config.patch_size
|
| 246 |
x = rearrange(
|
|
|
|
| 377 |
vae_source_input = vae_source_input.to(dtype=dtype)
|
| 378 |
|
| 379 |
# --- 7. Encode reference images ---
|
| 380 |
+
# vae_ref_image = None
|
| 381 |
+
# if ref_image is not None:
|
| 382 |
+
# vae_ref_image = []
|
| 383 |
+
# for item in ref_image:
|
| 384 |
+
# target_size = (width, height)
|
| 385 |
+
# item = ImageOps.pad(item, target_size, color="white", centering=(0.5, 0.5))
|
| 386 |
+
# ref_tensor = self._preprocess_video([item], dtype=torch.float32, device=device)
|
| 387 |
+
# ref_latent = self.vae.encode(ref_tensor).latent_dist.sample()
|
| 388 |
+
# vae_ref_image.append(ref_latent.to(dtype=dtype))
|
| 389 |
|
| 390 |
# --- 8. Handle input_video (video-to-video) ---
|
| 391 |
if input_video is not None:
|
|
|
|
| 406 |
timestep=timestep,
|
| 407 |
context=context,
|
| 408 |
vae_source_input=vae_source_input,
|
| 409 |
+
# vae_ref_image=vae_ref_image,
|
| 410 |
sigmas=sigmas,
|
| 411 |
timesteps_schedule=timesteps,
|
| 412 |
)
|