Spaces:
Configuration error
Configuration error
Commit ·
6f279fd
1
Parent(s): 3a3a124
Fixes
Browse files
app.py
CHANGED
|
@@ -113,7 +113,6 @@ def get_module_dtype(module):
|
|
| 113 |
|
| 114 |
|
| 115 |
def load_generation_pipe():
|
| 116 |
-
log_cuda_mem("before load_generation_pipe")
|
| 117 |
image_encoder = CLIPVisionModel.from_pretrained(
|
| 118 |
MODEL_ID,
|
| 119 |
subfolder="image_encoder",
|
|
@@ -130,24 +129,20 @@ def load_generation_pipe():
|
|
| 130 |
image_encoder=image_encoder,
|
| 131 |
torch_dtype=PIPE_DTYPE,
|
| 132 |
)
|
| 133 |
-
log_cuda_mem("after load_generation_pipe")
|
| 134 |
return pipe
|
| 135 |
|
| 136 |
|
| 137 |
def load_wan_vae():
|
| 138 |
-
log_cuda_mem("before load_wan_vae")
|
| 139 |
vae = DiffusersWanVAE.from_pretrained(
|
| 140 |
MODEL_ID,
|
| 141 |
subfolder="vae",
|
| 142 |
torch_dtype=PIPE_DTYPE,
|
| 143 |
)
|
| 144 |
vae.eval()
|
| 145 |
-
log_cuda_mem("after load_wan_vae")
|
| 146 |
return vae
|
| 147 |
|
| 148 |
|
| 149 |
def load_refdecoder_module():
|
| 150 |
-
log_cuda_mem("before load_refdecoder_module")
|
| 151 |
vae = AutoencoderKLWan(
|
| 152 |
dropout_p=0.0,
|
| 153 |
use_reference=True,
|
|
@@ -175,7 +170,6 @@ def load_refdecoder_module():
|
|
| 175 |
vae.load_state_dict(vae_sd, strict=False)
|
| 176 |
transformer.load_state_dict(transformer_sd, strict=False)
|
| 177 |
|
| 178 |
-
log_cuda_mem("after load_refdecoder_module")
|
| 179 |
return vae, transformer
|
| 180 |
|
| 181 |
|
|
@@ -561,7 +555,11 @@ def decode_with_refdecoder(latents, reference_frame, vae, transformer):
|
|
| 561 |
return video
|
| 562 |
|
| 563 |
|
| 564 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
assert CHUNK_BOUNDARIES[-1] == NUM_INFERENCE_STEPS
|
| 566 |
|
| 567 |
|
|
@@ -608,12 +606,18 @@ def _run_diffusion_steps(
|
|
| 608 |
|
| 609 |
@spaces.GPU(duration=50)
|
| 610 |
def generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width):
|
| 611 |
-
"""Encode prompt+image, prepare latents
|
| 612 |
|
| 613 |
-
|
|
|
|
| 614 |
"""
|
| 615 |
log_cuda_mem("start generate_latents_setup_on_gpu")
|
| 616 |
-
GENERATION_PIPE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 617 |
try:
|
| 618 |
transformer_dtype = GENERATION_PIPE.transformer.dtype
|
| 619 |
|
|
@@ -631,9 +635,6 @@ def generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width):
|
|
| 631 |
image_embeds = GENERATION_PIPE.encode_image(resized_image, DEVICE)
|
| 632 |
image_embeds = image_embeds.repeat(1, 1, 1).to(transformer_dtype)
|
| 633 |
|
| 634 |
-
GENERATION_PIPE.scheduler.set_timesteps(NUM_INFERENCE_STEPS, device=DEVICE)
|
| 635 |
-
timesteps = GENERATION_PIPE.scheduler.timesteps
|
| 636 |
-
|
| 637 |
image_tensor = GENERATION_PIPE.video_processor.preprocess(
|
| 638 |
resized_image, height=height, width=width
|
| 639 |
).to(DEVICE, dtype=torch.float32)
|
|
@@ -652,29 +653,18 @@ def generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width):
|
|
| 652 |
None,
|
| 653 |
)
|
| 654 |
|
| 655 |
-
end_step = CHUNK_BOUNDARIES[0]
|
| 656 |
-
latents = _run_diffusion_steps(
|
| 657 |
-
latents,
|
| 658 |
-
condition,
|
| 659 |
-
prompt_embeds,
|
| 660 |
-
negative_prompt_embeds,
|
| 661 |
-
image_embeds,
|
| 662 |
-
timesteps,
|
| 663 |
-
0,
|
| 664 |
-
end_step,
|
| 665 |
-
transformer_dtype,
|
| 666 |
-
)
|
| 667 |
-
|
| 668 |
state = {
|
| 669 |
"prompt_embeds": prompt_embeds.detach().cpu(),
|
| 670 |
"negative_prompt_embeds": negative_prompt_embeds.detach().cpu(),
|
| 671 |
"image_embeds": image_embeds.detach().cpu(),
|
| 672 |
"condition": condition.detach().cpu(),
|
| 673 |
"latents": latents.detach().cpu(),
|
| 674 |
-
"step_idx":
|
| 675 |
}
|
| 676 |
finally:
|
| 677 |
-
|
|
|
|
|
|
|
| 678 |
log_cuda_mem("end generate_latents_setup_on_gpu")
|
| 679 |
return state
|
| 680 |
|
|
@@ -762,16 +752,17 @@ def generate_and_decode(image, prompt, seed, progress=gr.Progress()):
|
|
| 762 |
run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
|
| 763 |
run_dir.mkdir(parents=True, exist_ok=True)
|
| 764 |
|
| 765 |
-
|
| 766 |
-
|
|
|
|
| 767 |
|
| 768 |
t0 = time.perf_counter()
|
| 769 |
resized_image, height, width = resize_image_for_wan(image, GENERATION_PIPE)
|
| 770 |
state = generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width)
|
| 771 |
-
for chunk_idx, end_step in enumerate(CHUNK_BOUNDARIES
|
| 772 |
progress(
|
| 773 |
-
0.8 * (chunk_idx - 1) /
|
| 774 |
-
desc=f"Generating latents ({chunk_idx}/{
|
| 775 |
)
|
| 776 |
state = generate_latents_chunk_on_gpu(state, end_step)
|
| 777 |
latents = normalize_latent_shape(state["latents"])
|
|
|
|
| 113 |
|
| 114 |
|
| 115 |
def load_generation_pipe():
|
|
|
|
| 116 |
image_encoder = CLIPVisionModel.from_pretrained(
|
| 117 |
MODEL_ID,
|
| 118 |
subfolder="image_encoder",
|
|
|
|
| 129 |
image_encoder=image_encoder,
|
| 130 |
torch_dtype=PIPE_DTYPE,
|
| 131 |
)
|
|
|
|
| 132 |
return pipe
|
| 133 |
|
| 134 |
|
| 135 |
def load_wan_vae():
|
|
|
|
| 136 |
vae = DiffusersWanVAE.from_pretrained(
|
| 137 |
MODEL_ID,
|
| 138 |
subfolder="vae",
|
| 139 |
torch_dtype=PIPE_DTYPE,
|
| 140 |
)
|
| 141 |
vae.eval()
|
|
|
|
| 142 |
return vae
|
| 143 |
|
| 144 |
|
| 145 |
def load_refdecoder_module():
|
|
|
|
| 146 |
vae = AutoencoderKLWan(
|
| 147 |
dropout_p=0.0,
|
| 148 |
use_reference=True,
|
|
|
|
| 170 |
vae.load_state_dict(vae_sd, strict=False)
|
| 171 |
transformer.load_state_dict(transformer_sd, strict=False)
|
| 172 |
|
|
|
|
| 173 |
return vae, transformer
|
| 174 |
|
| 175 |
|
|
|
|
| 555 |
return video
|
| 556 |
|
| 557 |
|
| 558 |
+
_NUM_DENOISING_CHUNKS = 4
|
| 559 |
+
CHUNK_BOUNDARIES = tuple(
|
| 560 |
+
NUM_INFERENCE_STEPS * (i + 1) // _NUM_DENOISING_CHUNKS
|
| 561 |
+
for i in range(_NUM_DENOISING_CHUNKS)
|
| 562 |
+
)
|
| 563 |
assert CHUNK_BOUNDARIES[-1] == NUM_INFERENCE_STEPS
|
| 564 |
|
| 565 |
|
|
|
|
| 606 |
|
| 607 |
@spaces.GPU(duration=50)
|
| 608 |
def generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width):
|
| 609 |
+
"""Encode prompt+image, prepare initial latents and condition. NO denoising.
|
| 610 |
|
| 611 |
+
Loads only the encoders + VAE to GPU (not the 14B transformer). Returns a
|
| 612 |
+
CPU-resident state dict consumable by generate_latents_chunk_on_gpu.
|
| 613 |
"""
|
| 614 |
log_cuda_mem("start generate_latents_setup_on_gpu")
|
| 615 |
+
text_encoder = GENERATION_PIPE.text_encoder
|
| 616 |
+
image_encoder = GENERATION_PIPE.image_encoder
|
| 617 |
+
vae = GENERATION_PIPE.vae
|
| 618 |
+
text_encoder.to(DEVICE)
|
| 619 |
+
image_encoder.to(DEVICE)
|
| 620 |
+
vae.to(DEVICE)
|
| 621 |
try:
|
| 622 |
transformer_dtype = GENERATION_PIPE.transformer.dtype
|
| 623 |
|
|
|
|
| 635 |
image_embeds = GENERATION_PIPE.encode_image(resized_image, DEVICE)
|
| 636 |
image_embeds = image_embeds.repeat(1, 1, 1).to(transformer_dtype)
|
| 637 |
|
|
|
|
|
|
|
|
|
|
| 638 |
image_tensor = GENERATION_PIPE.video_processor.preprocess(
|
| 639 |
resized_image, height=height, width=width
|
| 640 |
).to(DEVICE, dtype=torch.float32)
|
|
|
|
| 653 |
None,
|
| 654 |
)
|
| 655 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 656 |
state = {
|
| 657 |
"prompt_embeds": prompt_embeds.detach().cpu(),
|
| 658 |
"negative_prompt_embeds": negative_prompt_embeds.detach().cpu(),
|
| 659 |
"image_embeds": image_embeds.detach().cpu(),
|
| 660 |
"condition": condition.detach().cpu(),
|
| 661 |
"latents": latents.detach().cpu(),
|
| 662 |
+
"step_idx": 0,
|
| 663 |
}
|
| 664 |
finally:
|
| 665 |
+
text_encoder.to("cpu")
|
| 666 |
+
image_encoder.to("cpu")
|
| 667 |
+
vae.to("cpu")
|
| 668 |
log_cuda_mem("end generate_latents_setup_on_gpu")
|
| 669 |
return state
|
| 670 |
|
|
|
|
| 752 |
run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
|
| 753 |
run_dir.mkdir(parents=True, exist_ok=True)
|
| 754 |
|
| 755 |
+
# 1 setup chunk (encoders + VAE) + len(CHUNK_BOUNDARIES) denoising chunks.
|
| 756 |
+
total_chunks = 1 + len(CHUNK_BOUNDARIES)
|
| 757 |
+
progress(0.0, desc=f"Generating latents (1/{total_chunks})")
|
| 758 |
|
| 759 |
t0 = time.perf_counter()
|
| 760 |
resized_image, height, width = resize_image_for_wan(image, GENERATION_PIPE)
|
| 761 |
state = generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width)
|
| 762 |
+
for chunk_idx, end_step in enumerate(CHUNK_BOUNDARIES, start=2):
|
| 763 |
progress(
|
| 764 |
+
0.8 * (chunk_idx - 1) / total_chunks,
|
| 765 |
+
desc=f"Generating latents ({chunk_idx}/{total_chunks})",
|
| 766 |
)
|
| 767 |
state = generate_latents_chunk_on_gpu(state, end_step)
|
| 768 |
latents = normalize_latent_shape(state["latents"])
|