xiangfan00 committed on
Commit
6f279fd
·
1 Parent(s): 3a3a124
Files changed (1) hide show
  1. app.py +24 -33
app.py CHANGED
@@ -113,7 +113,6 @@ def get_module_dtype(module):
113
 
114
 
115
  def load_generation_pipe():
116
- log_cuda_mem("before load_generation_pipe")
117
  image_encoder = CLIPVisionModel.from_pretrained(
118
  MODEL_ID,
119
  subfolder="image_encoder",
@@ -130,24 +129,20 @@ def load_generation_pipe():
130
  image_encoder=image_encoder,
131
  torch_dtype=PIPE_DTYPE,
132
  )
133
- log_cuda_mem("after load_generation_pipe")
134
  return pipe
135
 
136
 
137
  def load_wan_vae():
138
- log_cuda_mem("before load_wan_vae")
139
  vae = DiffusersWanVAE.from_pretrained(
140
  MODEL_ID,
141
  subfolder="vae",
142
  torch_dtype=PIPE_DTYPE,
143
  )
144
  vae.eval()
145
- log_cuda_mem("after load_wan_vae")
146
  return vae
147
 
148
 
149
  def load_refdecoder_module():
150
- log_cuda_mem("before load_refdecoder_module")
151
  vae = AutoencoderKLWan(
152
  dropout_p=0.0,
153
  use_reference=True,
@@ -175,7 +170,6 @@ def load_refdecoder_module():
175
  vae.load_state_dict(vae_sd, strict=False)
176
  transformer.load_state_dict(transformer_sd, strict=False)
177
 
178
- log_cuda_mem("after load_refdecoder_module")
179
  return vae, transformer
180
 
181
 
@@ -561,7 +555,11 @@ def decode_with_refdecoder(latents, reference_frame, vae, transformer):
561
  return video
562
 
563
 
564
- CHUNK_BOUNDARIES = (8, 16, 23, NUM_INFERENCE_STEPS)
 
 
 
 
565
  assert CHUNK_BOUNDARIES[-1] == NUM_INFERENCE_STEPS
566
 
567
 
@@ -608,12 +606,18 @@ def _run_diffusion_steps(
608
 
609
  @spaces.GPU(duration=50)
610
  def generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width):
611
- """Encode prompt+image, prepare latents, run the first chunk of denoising steps.
612
 
613
- Returns a CPU-resident state dict consumable by generate_latents_chunk_on_gpu.
 
614
  """
615
  log_cuda_mem("start generate_latents_setup_on_gpu")
616
- GENERATION_PIPE.to(DEVICE)
 
 
 
 
 
617
  try:
618
  transformer_dtype = GENERATION_PIPE.transformer.dtype
619
 
@@ -631,9 +635,6 @@ def generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width):
631
  image_embeds = GENERATION_PIPE.encode_image(resized_image, DEVICE)
632
  image_embeds = image_embeds.repeat(1, 1, 1).to(transformer_dtype)
633
 
634
- GENERATION_PIPE.scheduler.set_timesteps(NUM_INFERENCE_STEPS, device=DEVICE)
635
- timesteps = GENERATION_PIPE.scheduler.timesteps
636
-
637
  image_tensor = GENERATION_PIPE.video_processor.preprocess(
638
  resized_image, height=height, width=width
639
  ).to(DEVICE, dtype=torch.float32)
@@ -652,29 +653,18 @@ def generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width):
652
  None,
653
  )
654
 
655
- end_step = CHUNK_BOUNDARIES[0]
656
- latents = _run_diffusion_steps(
657
- latents,
658
- condition,
659
- prompt_embeds,
660
- negative_prompt_embeds,
661
- image_embeds,
662
- timesteps,
663
- 0,
664
- end_step,
665
- transformer_dtype,
666
- )
667
-
668
  state = {
669
  "prompt_embeds": prompt_embeds.detach().cpu(),
670
  "negative_prompt_embeds": negative_prompt_embeds.detach().cpu(),
671
  "image_embeds": image_embeds.detach().cpu(),
672
  "condition": condition.detach().cpu(),
673
  "latents": latents.detach().cpu(),
674
- "step_idx": end_step,
675
  }
676
  finally:
677
- GENERATION_PIPE.to("cpu")
 
 
678
  log_cuda_mem("end generate_latents_setup_on_gpu")
679
  return state
680
 
@@ -762,16 +752,17 @@ def generate_and_decode(image, prompt, seed, progress=gr.Progress()):
762
  run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
763
  run_dir.mkdir(parents=True, exist_ok=True)
764
 
765
- num_chunks = len(CHUNK_BOUNDARIES)
766
- progress(0.0, desc=f"Generating latents (1/{num_chunks})")
 
767
 
768
  t0 = time.perf_counter()
769
  resized_image, height, width = resize_image_for_wan(image, GENERATION_PIPE)
770
  state = generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width)
771
- for chunk_idx, end_step in enumerate(CHUNK_BOUNDARIES[1:], start=2):
772
  progress(
773
- 0.8 * (chunk_idx - 1) / num_chunks,
774
- desc=f"Generating latents ({chunk_idx}/{num_chunks})",
775
  )
776
  state = generate_latents_chunk_on_gpu(state, end_step)
777
  latents = normalize_latent_shape(state["latents"])
 
113
 
114
 
115
  def load_generation_pipe():
 
116
  image_encoder = CLIPVisionModel.from_pretrained(
117
  MODEL_ID,
118
  subfolder="image_encoder",
 
129
  image_encoder=image_encoder,
130
  torch_dtype=PIPE_DTYPE,
131
  )
 
132
  return pipe
133
 
134
 
135
def load_wan_vae():
    """Load the pretrained Wan VAE (``vae`` subfolder of MODEL_ID) in eval mode.

    Returns:
        The ``DiffusersWanVAE`` instance, cast to ``PIPE_DTYPE`` and switched
        to inference mode. Device placement is left to the caller.
    """
    model = DiffusersWanVAE.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=PIPE_DTYPE
    )
    # eval() disables dropout/batch-norm training behavior for inference.
    model.eval()
    return model
143
 
144
 
145
  def load_refdecoder_module():
 
146
  vae = AutoencoderKLWan(
147
  dropout_p=0.0,
148
  use_reference=True,
 
170
  vae.load_state_dict(vae_sd, strict=False)
171
  transformer.load_state_dict(transformer_sd, strict=False)
172
 
 
173
  return vae, transformer
174
 
175
 
 
555
  return video
556
 
557
 
558
# Number of pieces the denoising loop is split into — presumably so each piece
# fits within the @spaces.GPU duration budget (TODO confirm with the GPU decorators).
_NUM_DENOISING_CHUNKS = 4
# Cumulative end-step index for each chunk; floor division spreads the steps
# as evenly as possible across the chunks.
CHUNK_BOUNDARIES = tuple(
    NUM_INFERENCE_STEPS * chunk_end // _NUM_DENOISING_CHUNKS
    for chunk_end in range(1, _NUM_DENOISING_CHUNKS + 1)
)
# Sanity check: the last boundary must land exactly on the final step.
assert CHUNK_BOUNDARIES[-1] == NUM_INFERENCE_STEPS
564
 
565
 
 
606
 
607
  @spaces.GPU(duration=50)
608
  def generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width):
609
+ """Encode prompt+image, prepare initial latents and condition. NO denoising.
610
 
611
+ Loads only the encoders + VAE to GPU (not the 14B transformer). Returns a
612
+ CPU-resident state dict consumable by generate_latents_chunk_on_gpu.
613
  """
614
  log_cuda_mem("start generate_latents_setup_on_gpu")
615
+ text_encoder = GENERATION_PIPE.text_encoder
616
+ image_encoder = GENERATION_PIPE.image_encoder
617
+ vae = GENERATION_PIPE.vae
618
+ text_encoder.to(DEVICE)
619
+ image_encoder.to(DEVICE)
620
+ vae.to(DEVICE)
621
  try:
622
  transformer_dtype = GENERATION_PIPE.transformer.dtype
623
 
 
635
  image_embeds = GENERATION_PIPE.encode_image(resized_image, DEVICE)
636
  image_embeds = image_embeds.repeat(1, 1, 1).to(transformer_dtype)
637
 
 
 
 
638
  image_tensor = GENERATION_PIPE.video_processor.preprocess(
639
  resized_image, height=height, width=width
640
  ).to(DEVICE, dtype=torch.float32)
 
653
  None,
654
  )
655
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656
  state = {
657
  "prompt_embeds": prompt_embeds.detach().cpu(),
658
  "negative_prompt_embeds": negative_prompt_embeds.detach().cpu(),
659
  "image_embeds": image_embeds.detach().cpu(),
660
  "condition": condition.detach().cpu(),
661
  "latents": latents.detach().cpu(),
662
+ "step_idx": 0,
663
  }
664
  finally:
665
+ text_encoder.to("cpu")
666
+ image_encoder.to("cpu")
667
+ vae.to("cpu")
668
  log_cuda_mem("end generate_latents_setup_on_gpu")
669
  return state
670
 
 
752
  run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
753
  run_dir.mkdir(parents=True, exist_ok=True)
754
 
755
+ # 1 setup chunk (encoders + VAE) + len(CHUNK_BOUNDARIES) denoising chunks.
756
+ total_chunks = 1 + len(CHUNK_BOUNDARIES)
757
+ progress(0.0, desc=f"Generating latents (1/{total_chunks})")
758
 
759
  t0 = time.perf_counter()
760
  resized_image, height, width = resize_image_for_wan(image, GENERATION_PIPE)
761
  state = generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width)
762
+ for chunk_idx, end_step in enumerate(CHUNK_BOUNDARIES, start=2):
763
  progress(
764
+ 0.8 * (chunk_idx - 1) / total_chunks,
765
+ desc=f"Generating latents ({chunk_idx}/{total_chunks})",
766
  )
767
  state = generate_latents_chunk_on_gpu(state, end_step)
768
  latents = normalize_latent_shape(state["latents"])