xiangfan00 committed
Commit 3a3a124 · 1 Parent(s): 5a9bcbe

Try using multiple stages

Files changed (1): app.py +160 -24
app.py CHANGED
@@ -561,32 +561,159 @@ def decode_with_refdecoder(latents, reference_frame, vae, transformer):
     return video
 
 
-@spaces.GPU(duration=160)
-def generate_latents_on_gpu(image, prompt, seed):
-    log_cuda_mem("start generate_latents_on_gpu")
+CHUNK_BOUNDARIES = (8, 16, 23, NUM_INFERENCE_STEPS)
+assert CHUNK_BOUNDARIES[-1] == NUM_INFERENCE_STEPS
+
+
+def _run_diffusion_steps(
+    latents,
+    condition,
+    prompt_embeds,
+    negative_prompt_embeds,
+    image_embeds,
+    timesteps,
+    start_step,
+    end_step,
+    transformer_dtype,
+):
+    """Inlined Wan 2.1 I2V denoising loop. Runs steps [start_step, end_step)."""
+    transformer = GENERATION_PIPE.transformer
+    scheduler = GENERATION_PIPE.scheduler
+    with torch.no_grad():
+        for i in range(start_step, end_step):
+            t = timesteps[i]
+            latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+            timestep = t.expand(latents.shape[0])
+
+            with transformer.cache_context("cond"):
+                noise_pred = transformer(
+                    hidden_states=latent_model_input,
+                    timestep=timestep,
+                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states_image=image_embeds,
+                    return_dict=False,
+                )[0]
+            with transformer.cache_context("uncond"):
+                noise_uncond = transformer(
+                    hidden_states=latent_model_input,
+                    timestep=timestep,
+                    encoder_hidden_states=negative_prompt_embeds,
+                    encoder_hidden_states_image=image_embeds,
+                    return_dict=False,
+                )[0]
+            noise_pred = noise_uncond + GUIDANCE_SCALE * (noise_pred - noise_uncond)
+            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+    return latents
+
+
+@spaces.GPU(duration=50)
+def generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width):
+    """Encode prompt+image, prepare latents, run the first chunk of denoising steps.
+
+    Returns a CPU-resident state dict consumable by generate_latents_chunk_on_gpu.
+    """
+    log_cuda_mem("start generate_latents_setup_on_gpu")
     GENERATION_PIPE.to(DEVICE)
-    log_cuda_mem("after pipe -> cuda")
-    resized_image, height, width = resize_image_for_wan(image, GENERATION_PIPE)
-    generator = torch.Generator(device=DEVICE).manual_seed(seed)
     try:
-        with torch.no_grad():
-            output = GENERATION_PIPE(
-                image=resized_image,
-                prompt=prompt,
-                negative_prompt=NEGATIVE_PROMPT,
-                height=height,
-                width=width,
-                num_frames=NUM_FRAMES,
-                num_inference_steps=NUM_INFERENCE_STEPS,
-                guidance_scale=GUIDANCE_SCALE,
-                generator=generator,
-                output_type="latent",
-            )
-        latents = normalize_latent_shape(output.frames).detach().cpu()
+        transformer_dtype = GENERATION_PIPE.transformer.dtype
+
+        prompt_embeds, negative_prompt_embeds = GENERATION_PIPE.encode_prompt(
+            prompt=prompt,
+            negative_prompt=NEGATIVE_PROMPT,
+            do_classifier_free_guidance=True,
+            num_videos_per_prompt=1,
+            max_sequence_length=512,
+            device=DEVICE,
+        )
+        prompt_embeds = prompt_embeds.to(transformer_dtype)
+        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+        image_embeds = GENERATION_PIPE.encode_image(resized_image, DEVICE)
+        image_embeds = image_embeds.repeat(1, 1, 1).to(transformer_dtype)
+
+        GENERATION_PIPE.scheduler.set_timesteps(NUM_INFERENCE_STEPS, device=DEVICE)
+        timesteps = GENERATION_PIPE.scheduler.timesteps
+
+        image_tensor = GENERATION_PIPE.video_processor.preprocess(
+            resized_image, height=height, width=width
+        ).to(DEVICE, dtype=torch.float32)
+        generator = torch.Generator(device=DEVICE).manual_seed(seed)
+        latents, condition = GENERATION_PIPE.prepare_latents(
+            image_tensor,
+            1,
+            GENERATION_PIPE.vae.config.z_dim,
+            height,
+            width,
+            NUM_FRAMES,
+            torch.float32,
+            DEVICE,
+            generator,
+            None,
+            None,
+        )
+
+        end_step = CHUNK_BOUNDARIES[0]
+        latents = _run_diffusion_steps(
+            latents,
+            condition,
+            prompt_embeds,
+            negative_prompt_embeds,
+            image_embeds,
+            timesteps,
+            0,
+            end_step,
+            transformer_dtype,
+        )
+
+        state = {
+            "prompt_embeds": prompt_embeds.detach().cpu(),
+            "negative_prompt_embeds": negative_prompt_embeds.detach().cpu(),
+            "image_embeds": image_embeds.detach().cpu(),
+            "condition": condition.detach().cpu(),
+            "latents": latents.detach().cpu(),
+            "step_idx": end_step,
+        }
     finally:
         GENERATION_PIPE.to("cpu")
-        log_cuda_mem("after latent generation")
-    return latents, resized_image, height, width
+        log_cuda_mem("end generate_latents_setup_on_gpu")
+    return state
+
+
+@spaces.GPU(duration=50)
+def generate_latents_chunk_on_gpu(state, end_step):
+    """Run denoising steps from state['step_idx'] to end_step. Only transformer is moved to GPU."""
+    log_cuda_mem(f"start latents chunk -> step {end_step}")
+    transformer = GENERATION_PIPE.transformer
+    transformer.to(DEVICE)
+    try:
+        GENERATION_PIPE.scheduler.set_timesteps(NUM_INFERENCE_STEPS, device=DEVICE)
+        timesteps = GENERATION_PIPE.scheduler.timesteps
+        transformer_dtype = transformer.dtype
+
+        latents = state["latents"].to(DEVICE)
+        condition = state["condition"].to(DEVICE)
+        prompt_embeds = state["prompt_embeds"].to(DEVICE)
+        negative_prompt_embeds = state["negative_prompt_embeds"].to(DEVICE)
+        image_embeds = state["image_embeds"].to(DEVICE)
+
+        latents = _run_diffusion_steps(
+            latents,
+            condition,
+            prompt_embeds,
+            negative_prompt_embeds,
+            image_embeds,
+            timesteps,
+            state["step_idx"],
+            end_step,
+            transformer_dtype,
+        )
+
+        state["latents"] = latents.detach().cpu()
+        state["step_idx"] = end_step
+    finally:
+        transformer.to("cpu")
+        log_cuda_mem(f"end latents chunk -> step {end_step}")
+    return state
 
 
 @spaces.GPU(duration=20)
@@ -635,10 +762,19 @@ def generate_and_decode(image, prompt, seed, progress=gr.Progress()):
     run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
     run_dir.mkdir(parents=True, exist_ok=True)
 
-    progress(0.0, desc="Generating latents")
+    num_chunks = len(CHUNK_BOUNDARIES)
+    progress(0.0, desc=f"Generating latents (1/{num_chunks})")
 
     t0 = time.perf_counter()
-    latents, resized_image, height, width = generate_latents_on_gpu(image, prompt, seed)
+    resized_image, height, width = resize_image_for_wan(image, GENERATION_PIPE)
+    state = generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width)
+    for chunk_idx, end_step in enumerate(CHUNK_BOUNDARIES[1:], start=2):
+        progress(
+            0.8 * (chunk_idx - 1) / num_chunks,
+            desc=f"Generating latents ({chunk_idx}/{num_chunks})",
+        )
+        state = generate_latents_chunk_on_gpu(state, end_step)
+    latents = normalize_latent_shape(state["latents"])
    latent_secs = time.perf_counter() - t0
    print(f"[timing] latent generation: {latent_secs:.2f}s")
    reference_frame = build_reference_frame(resized_image, "cpu")
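
For context on the pattern this commit adopts: on ZeroGPU Spaces each `@spaces.GPU` call must finish within its `duration` budget, so the single 160-second denoising call is split into several ~50-second calls, with all state parked on CPU between leases. Below is a minimal, hypothetical sketch of that round-trip pattern, not code from app.py; `DummyDenoiser`-style globals, `run_steps`, and `BOUNDARIES` are stand-ins for the real transformer, `_run_diffusion_steps`, and `CHUNK_BOUNDARIES`.

import torch
import spaces  # provided by the Hugging Face `spaces` package on ZeroGPU

DEVICE = "cuda"
BOUNDARIES = (8, 16, 23, 30)  # cumulative step boundaries; last entry == total steps

model = torch.nn.Linear(16, 16)  # hypothetical stand-in for the transformer


def run_steps(latents, start, end):
    # Stand-in for the real denoising loop over steps [start, end).
    with torch.no_grad():
        for _ in range(start, end):
            latents = model(latents)
    return latents


@spaces.GPU(duration=50)
def run_chunk(state, end_step):
    # Each GPU lease: move module + tensors up, advance a few steps, ship back.
    model.to(DEVICE)
    try:
        latents = state["latents"].to(DEVICE)
        latents = run_steps(latents, state["step_idx"], end_step)
        state["latents"] = latents.detach().cpu()  # back to CPU before the lease ends
        state["step_idx"] = end_step
    finally:
        model.to("cpu")  # release GPU memory so the next lease starts clean
    return state


state = {"latents": torch.randn(1, 16), "step_idx": 0}
for end_step in BOUNDARIES:
    state = run_chunk(state, end_step)

The serialize-to-CPU step is what makes the chunks independent: since nothing CUDA-resident survives between calls, each `@spaces.GPU` invocation can land on a fresh device without stale references.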