Spaces:
Configuration error
Configuration error
Commit ·
3a3a124
1
Parent(s): 5a9bcbe
Try using multiple stages
Browse files
app.py
CHANGED
|
@@ -561,32 +561,159 @@ def decode_with_refdecoder(latents, reference_frame, vae, transformer):
|
|
| 561 |
return video
|
| 562 |
|
| 563 |
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 567 |
GENERATION_PIPE.to(DEVICE)
|
| 568 |
-
log_cuda_mem("after pipe -> cuda")
|
| 569 |
-
resized_image, height, width = resize_image_for_wan(image, GENERATION_PIPE)
|
| 570 |
-
generator = torch.Generator(device=DEVICE).manual_seed(seed)
|
| 571 |
try:
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 586 |
finally:
|
| 587 |
GENERATION_PIPE.to("cpu")
|
| 588 |
-
log_cuda_mem("
|
| 589 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 590 |
|
| 591 |
|
| 592 |
@spaces.GPU(duration=20)
|
|
@@ -635,10 +762,19 @@ def generate_and_decode(image, prompt, seed, progress=gr.Progress()):
|
|
| 635 |
run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
|
| 636 |
run_dir.mkdir(parents=True, exist_ok=True)
|
| 637 |
|
| 638 |
-
|
|
|
|
| 639 |
|
| 640 |
t0 = time.perf_counter()
|
| 641 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 642 |
latent_secs = time.perf_counter() - t0
|
| 643 |
print(f"[timing] latent generation: {latent_secs:.2f}s")
|
| 644 |
reference_frame = build_reference_frame(resized_image, "cpu")
|
|
|
|
| 561 |
return video
|
| 562 |
|
| 563 |
|
| 564 |
+
# Exclusive end-step index of each denoising chunk. The first entry is the end
# step of the setup call; every later entry is the end step of one follow-up
# chunk call. The last entry must be NUM_INFERENCE_STEPS so the final chunk
# finishes the schedule.
CHUNK_BOUNDARIES = (8, 16, 23, NUM_INFERENCE_STEPS)

# Validate explicitly instead of with `assert`: asserts are stripped under
# `python -O`, and comparing the last entry to NUM_INFERENCE_STEPS alone cannot
# detect a NUM_INFERENCE_STEPS that is not larger than the hard-coded
# boundaries (which would yield empty or out-of-range chunks).
if CHUNK_BOUNDARIES[-1] != NUM_INFERENCE_STEPS:
    raise ValueError(
        f"CHUNK_BOUNDARIES must end at NUM_INFERENCE_STEPS={NUM_INFERENCE_STEPS}, "
        f"got {CHUNK_BOUNDARIES}"
    )
if any(
    previous >= current
    for previous, current in zip((0,) + CHUNK_BOUNDARIES, CHUNK_BOUNDARIES)
):
    raise ValueError(
        f"CHUNK_BOUNDARIES must be positive and strictly increasing, got {CHUNK_BOUNDARIES}"
    )
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def _run_diffusion_steps(
    latents,
    condition,
    prompt_embeds,
    negative_prompt_embeds,
    image_embeds,
    timesteps,
    start_step,
    end_step,
    transformer_dtype,
):
    """Advance ``latents`` through denoising steps ``[start_step, end_step)``.

    Inlined version of the Wan 2.1 image-to-video denoising loop. For each
    step, ``latents`` and ``condition`` are concatenated along dim 1 and cast
    to ``transformer_dtype``; the pipeline transformer is run once with the
    positive and once with the negative prompt embeddings; the two predictions
    are combined with ``GUIDANCE_SCALE`` (classifier-free guidance); and the
    pipeline scheduler produces the next latents.

    Args:
        latents: Current latents tensor.
        condition: Conditioning tensor concatenated to the latents on dim 1.
        prompt_embeds: Text embeddings for the positive prompt.
        negative_prompt_embeds: Text embeddings for the negative prompt.
        image_embeds: Image embeddings passed to both transformer calls.
        timesteps: Indexable sequence of scheduler timesteps.
        start_step: Index of the first step to run (inclusive).
        end_step: Index at which to stop (exclusive).
        transformer_dtype: Dtype the transformer input is cast to.

    Returns:
        The latents after the last executed step; the input latents unchanged
        when the step range is empty.
    """
    denoiser = GENERATION_PIPE.transformer
    sampler = GENERATION_PIPE.scheduler

    def _predict(model_input, batch_timestep, text_embeds, cache_name):
        # One transformer forward pass under its own named cache context.
        with denoiser.cache_context(cache_name):
            return denoiser(
                hidden_states=model_input,
                timestep=batch_timestep,
                encoder_hidden_states=text_embeds,
                encoder_hidden_states_image=image_embeds,
                return_dict=False,
            )[0]

    with torch.no_grad():
        for step in range(start_step, end_step):
            current_t = timesteps[step]
            model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
            batch_timestep = current_t.expand(latents.shape[0])

            cond_pred = _predict(model_input, batch_timestep, prompt_embeds, "cond")
            uncond_pred = _predict(
                model_input, batch_timestep, negative_prompt_embeds, "uncond"
            )

            guided_pred = uncond_pred + GUIDANCE_SCALE * (cond_pred - uncond_pred)
            latents = sampler.step(guided_pred, current_t, latents, return_dict=False)[0]
    return latents
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
@spaces.GPU(duration=50)
def generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width):
    """Encode prompt and image, prepare latents, and run the first denoising chunk.

    The whole generation pipeline is moved to ``DEVICE`` for the duration of
    the call and moved back to the CPU afterwards, even if an error occurs.

    Args:
        resized_image: Input image, already resized for the pipeline.
        prompt: Positive text prompt.
        seed: Seed for the latent noise generator.
        height: Target height passed to preprocessing and latent preparation.
        width: Target width passed to preprocessing and latent preparation.

    Returns:
        A dict of CPU tensors (``prompt_embeds``, ``negative_prompt_embeds``,
        ``image_embeds``, ``condition``, ``latents``) plus ``step_idx``, the
        index of the next denoising step to run. It is consumable by
        ``generate_latents_chunk_on_gpu``.
    """
    log_cuda_mem("start generate_latents_setup_on_gpu")
    try:
        # The device transfer lives inside the try block so that a failure
        # part-way through the move (e.g. out of memory) still triggers the
        # move back to the CPU in `finally`.
        GENERATION_PIPE.to(DEVICE)
        transformer_dtype = GENERATION_PIPE.transformer.dtype

        prompt_embeds, negative_prompt_embeds = GENERATION_PIPE.encode_prompt(
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            max_sequence_length=512,
            device=DEVICE,
        )
        prompt_embeds = prompt_embeds.to(transformer_dtype)
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

        image_embeds = GENERATION_PIPE.encode_image(resized_image, DEVICE)
        image_embeds = image_embeds.repeat(1, 1, 1).to(transformer_dtype)

        GENERATION_PIPE.scheduler.set_timesteps(NUM_INFERENCE_STEPS, device=DEVICE)
        timesteps = GENERATION_PIPE.scheduler.timesteps

        image_tensor = GENERATION_PIPE.video_processor.preprocess(
            resized_image, height=height, width=width
        ).to(DEVICE, dtype=torch.float32)
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
        # NOTE(review): positional arguments are presumably (image, batch_size,
        # num_channels_latents, height, width, num_frames, dtype, device,
        # generator, latents, last_image) -- confirm against the installed
        # diffusers version.
        latents, condition = GENERATION_PIPE.prepare_latents(
            image_tensor,
            1,
            GENERATION_PIPE.vae.config.z_dim,
            height,
            width,
            NUM_FRAMES,
            torch.float32,
            DEVICE,
            generator,
            None,
            None,
        )

        # The setup call also runs the first chunk: steps [0, end_step).
        end_step = CHUNK_BOUNDARIES[0]
        latents = _run_diffusion_steps(
            latents,
            condition,
            prompt_embeds,
            negative_prompt_embeds,
            image_embeds,
            timesteps,
            0,
            end_step,
            transformer_dtype,
        )

        # Everything handed back to the caller lives on the CPU so it can be
        # passed to the next GPU call.
        state = {
            "prompt_embeds": prompt_embeds.detach().cpu(),
            "negative_prompt_embeds": negative_prompt_embeds.detach().cpu(),
            "image_embeds": image_embeds.detach().cpu(),
            "condition": condition.detach().cpu(),
            "latents": latents.detach().cpu(),
            "step_idx": end_step,
        }
    finally:
        GENERATION_PIPE.to("cpu")
        log_cuda_mem("end generate_latents_setup_on_gpu")
    return state
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
@spaces.GPU(duration=50)
def generate_latents_chunk_on_gpu(state, end_step):
    """Run denoising steps from ``state['step_idx']`` up to ``end_step``.

    Only the transformer is moved to ``DEVICE``; it is moved back to the CPU
    afterwards, even if an error occurs.

    Args:
        state: Dict with the CPU tensors ``latents``, ``condition``,
            ``prompt_embeds``, ``negative_prompt_embeds``, ``image_embeds``
            and the integer ``step_idx`` (next step to run).
        end_step: Step index at which this chunk stops (exclusive).

    Returns:
        The same ``state`` dict, updated in place with the new CPU ``latents``
        and ``step_idx`` set to ``end_step``.
    """
    log_cuda_mem(f"start latents chunk -> step {end_step}")
    transformer = GENERATION_PIPE.transformer
    try:
        # The device transfer lives inside the try block so that a failure
        # part-way through the move (e.g. out of memory) still triggers the
        # move back to the CPU in `finally`.
        transformer.to(DEVICE)

        # The timestep schedule is rebuilt on every chunk call.
        # NOTE(review): if the scheduler keeps per-run history between steps,
        # that history presumably does not carry over across chunks -- confirm
        # this is acceptable for the configured scheduler.
        GENERATION_PIPE.scheduler.set_timesteps(NUM_INFERENCE_STEPS, device=DEVICE)
        timesteps = GENERATION_PIPE.scheduler.timesteps
        transformer_dtype = transformer.dtype

        latents = state["latents"].to(DEVICE)
        condition = state["condition"].to(DEVICE)
        prompt_embeds = state["prompt_embeds"].to(DEVICE)
        negative_prompt_embeds = state["negative_prompt_embeds"].to(DEVICE)
        image_embeds = state["image_embeds"].to(DEVICE)

        latents = _run_diffusion_steps(
            latents,
            condition,
            prompt_embeds,
            negative_prompt_embeds,
            image_embeds,
            timesteps,
            state["step_idx"],
            end_step,
            transformer_dtype,
        )

        # Store the result back on the CPU so the state can be handed to the
        # next GPU call.
        state["latents"] = latents.detach().cpu()
        state["step_idx"] = end_step
    finally:
        transformer.to("cpu")
        log_cuda_mem(f"end latents chunk -> step {end_step}")
    return state
|
| 717 |
|
| 718 |
|
| 719 |
@spaces.GPU(duration=20)
|
|
|
|
| 762 |
run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
|
| 763 |
run_dir.mkdir(parents=True, exist_ok=True)
|
| 764 |
|
| 765 |
+
num_chunks = len(CHUNK_BOUNDARIES)
|
| 766 |
+
progress(0.0, desc=f"Generating latents (1/{num_chunks})")
|
| 767 |
|
| 768 |
t0 = time.perf_counter()
|
| 769 |
+
resized_image, height, width = resize_image_for_wan(image, GENERATION_PIPE)
|
| 770 |
+
state = generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width)
|
| 771 |
+
for chunk_idx, end_step in enumerate(CHUNK_BOUNDARIES[1:], start=2):
|
| 772 |
+
progress(
|
| 773 |
+
0.8 * (chunk_idx - 1) / num_chunks,
|
| 774 |
+
desc=f"Generating latents ({chunk_idx}/{num_chunks})",
|
| 775 |
+
)
|
| 776 |
+
state = generate_latents_chunk_on_gpu(state, end_step)
|
| 777 |
+
latents = normalize_latent_shape(state["latents"])
|
| 778 |
latent_secs = time.perf_counter() - t0
|
| 779 |
print(f"[timing] latent generation: {latent_secs:.2f}s")
|
| 780 |
reference_frame = build_reference_frame(resized_image, "cpu")
|