Arrokothwhi committed on
Commit
8db3b31
·
1 Parent(s): 05c70dd

add memory print

Browse files
Files changed (1) hide show
  1. app.py +66 -45
app.py CHANGED
@@ -56,7 +56,25 @@ PIPE_DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
56
  pipeline_wan_i2v.ftfy = ftfy
57
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def load_generation_pipe():
 
60
  image_encoder = CLIPVisionModel.from_pretrained(
61
  MODEL_ID,
62
  subfolder="image_encoder",
@@ -74,10 +92,12 @@ def load_generation_pipe():
74
  torch_dtype=PIPE_DTYPE,
75
  )
76
  pipe = pipe.to(DEVICE)
 
77
  return pipe
78
 
79
 
80
  def load_wan_vae():
 
81
  vae = DiffusersWanVAE.from_pretrained(
82
  MODEL_ID,
83
  subfolder="vae",
@@ -85,10 +105,12 @@ def load_wan_vae():
85
  )
86
  vae = vae.to(DEVICE)
87
  vae.eval()
 
88
  return vae
89
 
90
 
91
  def load_refdecoder_module():
 
92
  vae = AutoencoderKLWan(
93
  dropout_p=0.0,
94
  use_reference=True,
@@ -122,6 +144,7 @@ def load_refdecoder_module():
122
 
123
  vae = vae.to(DEVICE).eval()
124
  transformer = transformer.to(DEVICE).eval()
 
125
  return vae, transformer
126
 
127
 
@@ -207,6 +230,7 @@ def generate_and_decode(image, prompt, seed, progress=gr.Progress(track_tqdm=Fal
207
  prompt = prompt.strip() if prompt else ""
208
  seed = int(seed) if seed is not None else random.randint(0, 2**32 - 1)
209
  run_dir = Path(tempfile.mkdtemp(prefix="refdecoder_demo_"))
 
210
 
211
  progress(0.05, desc="Loading Wan I2V pipeline")
212
  pipe = load_generation_pipe()
@@ -231,11 +255,13 @@ def generate_and_decode(image, prompt, seed, progress=gr.Progress(track_tqdm=Fal
231
  output_type="latent",
232
  )
233
  latents = normalize_latent_shape(output.frames).detach().cpu()
 
234
  del output
235
  del pipe
236
  if torch.cuda.is_available():
237
  torch.cuda.empty_cache()
238
  gc.collect()
 
239
 
240
  latent_path = run_dir / "wan_latents.pt"
241
  torch.save(
@@ -253,16 +279,19 @@ def generate_and_decode(image, prompt, seed, progress=gr.Progress(track_tqdm=Fal
253
  wan_vae = load_wan_vae()
254
  wan_video = decode_with_wan_vae(latents, wan_vae)
255
  wan_video_path = save_video_tensor(wan_video, run_dir / "wan_vae.mp4")
 
256
  del wan_video
257
  del wan_vae
258
  if torch.cuda.is_available():
259
  torch.cuda.empty_cache()
260
  gc.collect()
 
261
 
262
  progress(0.82, desc="Decoding with RefDecoder")
263
  ref_vae, ref_transformer = load_refdecoder_module()
264
  ref_video = decode_with_refdecoder(latents, reference_frame, ref_vae, ref_transformer)
265
  ref_video_path = save_video_tensor(ref_video, run_dir / "refdecoder.mp4")
 
266
  del ref_video
267
  del ref_vae
268
  del ref_transformer
@@ -270,16 +299,10 @@ def generate_and_decode(image, prompt, seed, progress=gr.Progress(track_tqdm=Fal
270
  if torch.cuda.is_available():
271
  torch.cuda.empty_cache()
272
  gc.collect()
 
273
 
274
- status = (
275
- f"Seed: {seed}\n"
276
- f"Prompt: {prompt if prompt else '(empty)'}\n"
277
- f"Resolution: {width}x{height}\n"
278
- f"Frames: {NUM_FRAMES}\n"
279
- f"Latents: {tuple(latents.shape)}"
280
- )
281
  progress(1.0, desc="Done")
282
- return wan_video_path, ref_video_path, status
283
 
284
 
285
  CUSTOM_CSS = """
@@ -292,7 +315,7 @@ CUSTOM_CSS = """
292
  --accent: #1f6a52;
293
  --accent-2: #c96f42;
294
  --text-main: #201a14;
295
- --text-soft: #5c5348;
296
  --copy-font: "Fraunces", "Iowan Old Style", "Palatino Linotype", serif;
297
  }
298
 
@@ -392,6 +415,10 @@ CUSTOM_CSS = """
392
  #generate-btn:hover {
393
  filter: brightness(1.04);
394
  }
 
 
 
 
395
  """
396
 
397
 
@@ -422,11 +449,11 @@ with gr.Blocks(title="RefDecoder I2V Demo", theme=gr.themes.Soft(), css=CUSTOM_C
422
  image_input = gr.Image(
423
  label="Reference Image",
424
  type="pil",
425
- height=420,
426
  )
427
  prompt_input = gr.Textbox(
428
  label="Motion Prompt",
429
- lines=5,
430
  placeholder="A woman turns toward the camera as her hair moves in the wind...",
431
  )
432
  seed_input = gr.Number(
@@ -441,44 +468,38 @@ with gr.Blocks(title="RefDecoder I2V Demo", theme=gr.themes.Soft(), css=CUSTOM_C
441
  elem_id="generate-btn",
442
  )
443
 
444
- with gr.Column(scale=6):
445
- with gr.Column(elem_classes="panel-card"):
446
- gr.HTML(
447
- """
448
- <div class="section-title">Run Info</div>
449
- <div class="section-copy">
450
- Generation details for the current comparison run.
451
- </div>
452
- """
453
- )
454
- status_output = gr.Textbox(
455
- label="Run Info",
456
- lines=7,
457
- interactive=False,
458
- )
459
-
460
- with gr.Column(elem_classes="output-card"):
461
- gr.HTML(
462
- """
463
- <div class="section-title">Wan Baseline</div>
464
- <div class="section-copy">Decoded with Wan2.1's original VAE.</div>
465
- """
466
- )
467
- wan_video_output = gr.Video(label="Wan VAE Decode", height=260)
468
-
469
- with gr.Column(elem_classes="output-card"):
470
- gr.HTML(
471
- """
472
- <div class="section-title">RefDecoder Result</div>
473
- <div class="section-copy">Decoded with the custom RefDecoder checkpoint.</div>
474
- """
475
- )
476
- ref_video_output = gr.Video(label="RefDecoder Decode", height=260)
477
 
478
  run_button.click(
479
  fn=generate_and_decode,
480
  inputs=[image_input, prompt_input, seed_input],
481
- outputs=[wan_video_output, ref_video_output, status_output],
482
  )
483
 
484
 
 
56
  pipeline_wan_i2v.ftfy = ftfy
57
 
58
 
def log_cuda_mem(tag, device=None):
    """Print a one-line snapshot of CUDA memory usage, labelled with *tag*.

    Parameters
    ----------
    tag : str
        Label identifying the point in the program where the snapshot
        is taken (e.g. ``"before load_wan_vae"``).
    device : int | torch.device | None, optional
        CUDA device to inspect.  ``None`` keeps the original behavior of
        querying the current device; passing an explicit device is a
        backward-compatible generalization.

    Returns
    -------
    None
        Output goes to stdout only.
    """
    if not torch.cuda.is_available():
        # Still emit a line so readers of the log can see the call happened.
        print(f"[mem] {tag}: CUDA not available")
        return

    # free/total come from the driver; allocated/reserved from the caching
    # allocator, so reserved >= allocated and total >= free + reserved-ish.
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    allocated_bytes = torch.cuda.memory_allocated(device)
    reserved_bytes = torch.cuda.memory_reserved(device)

    gib = 1024 ** 3  # bytes per GiB, hoisted out of the four conversions
    print(
        f"[mem] {tag}: "
        f"free={free_bytes / gib:.2f} GB, "
        f"total={total_bytes / gib:.2f} GB, "
        f"allocated={allocated_bytes / gib:.2f} GB, "
        f"reserved={reserved_bytes / gib:.2f} GB"
    )
76
  def load_generation_pipe():
77
+ log_cuda_mem("before load_generation_pipe")
78
  image_encoder = CLIPVisionModel.from_pretrained(
79
  MODEL_ID,
80
  subfolder="image_encoder",
 
92
  torch_dtype=PIPE_DTYPE,
93
  )
94
  pipe = pipe.to(DEVICE)
95
+ log_cuda_mem("after load_generation_pipe")
96
  return pipe
97
 
98
 
99
  def load_wan_vae():
100
+ log_cuda_mem("before load_wan_vae")
101
  vae = DiffusersWanVAE.from_pretrained(
102
  MODEL_ID,
103
  subfolder="vae",
 
105
  )
106
  vae = vae.to(DEVICE)
107
  vae.eval()
108
+ log_cuda_mem("after load_wan_vae")
109
  return vae
110
 
111
 
112
  def load_refdecoder_module():
113
+ log_cuda_mem("before load_refdecoder_module")
114
  vae = AutoencoderKLWan(
115
  dropout_p=0.0,
116
  use_reference=True,
 
144
 
145
  vae = vae.to(DEVICE).eval()
146
  transformer = transformer.to(DEVICE).eval()
147
+ log_cuda_mem("after load_refdecoder_module")
148
  return vae, transformer
149
 
150
 
 
230
  prompt = prompt.strip() if prompt else ""
231
  seed = int(seed) if seed is not None else random.randint(0, 2**32 - 1)
232
  run_dir = Path(tempfile.mkdtemp(prefix="refdecoder_demo_"))
233
+ log_cuda_mem("start generate_and_decode")
234
 
235
  progress(0.05, desc="Loading Wan I2V pipeline")
236
  pipe = load_generation_pipe()
 
255
  output_type="latent",
256
  )
257
  latents = normalize_latent_shape(output.frames).detach().cpu()
258
+ log_cuda_mem("after latent generation")
259
  del output
260
  del pipe
261
  if torch.cuda.is_available():
262
  torch.cuda.empty_cache()
263
  gc.collect()
264
+ log_cuda_mem("after freeing generation pipe")
265
 
266
  latent_path = run_dir / "wan_latents.pt"
267
  torch.save(
 
279
  wan_vae = load_wan_vae()
280
  wan_video = decode_with_wan_vae(latents, wan_vae)
281
  wan_video_path = save_video_tensor(wan_video, run_dir / "wan_vae.mp4")
282
+ log_cuda_mem("after wan decode")
283
  del wan_video
284
  del wan_vae
285
  if torch.cuda.is_available():
286
  torch.cuda.empty_cache()
287
  gc.collect()
288
+ log_cuda_mem("after freeing wan vae")
289
 
290
  progress(0.82, desc="Decoding with RefDecoder")
291
  ref_vae, ref_transformer = load_refdecoder_module()
292
  ref_video = decode_with_refdecoder(latents, reference_frame, ref_vae, ref_transformer)
293
  ref_video_path = save_video_tensor(ref_video, run_dir / "refdecoder.mp4")
294
+ log_cuda_mem("after refdecoder decode")
295
  del ref_video
296
  del ref_vae
297
  del ref_transformer
 
299
  if torch.cuda.is_available():
300
  torch.cuda.empty_cache()
301
  gc.collect()
302
+ log_cuda_mem("after freeing refdecoder")
303
 
 
 
 
 
 
 
 
304
  progress(1.0, desc="Done")
305
+ return wan_video_path, ref_video_path
306
 
307
 
308
  CUSTOM_CSS = """
 
315
  --accent: #1f6a52;
316
  --accent-2: #c96f42;
317
  --text-main: #201a14;
318
+ --text-soft: #201a14;
319
  --copy-font: "Fraunces", "Iowan Old Style", "Palatino Linotype", serif;
320
  }
321
 
 
415
  #generate-btn:hover {
416
  filter: brightness(1.04);
417
  }
418
+
419
+ .output-grid {
420
+ gap: 14px;
421
+ }
422
  """
423
 
424
 
 
449
  image_input = gr.Image(
450
  label="Reference Image",
451
  type="pil",
452
+ height=320,
453
  )
454
  prompt_input = gr.Textbox(
455
  label="Motion Prompt",
456
+ lines=4,
457
  placeholder="A woman turns toward the camera as her hair moves in the wind...",
458
  )
459
  seed_input = gr.Number(
 
468
  elem_id="generate-btn",
469
  )
470
 
471
+ with gr.Column(scale=6, elem_classes="panel-card"):
472
+ gr.HTML(
473
+ """
474
+ <div class="section-title">Decoder Comparison</div>
475
+ <div class="section-copy">
476
+ Same Wan latent video, rendered with two different decoders.
477
+ </div>
478
+ """
479
+ )
480
+ with gr.Row(equal_height=True, elem_classes="output-grid"):
481
+ with gr.Column(elem_classes="output-card"):
482
+ gr.HTML(
483
+ """
484
+ <div class="section-title">Wan Baseline</div>
485
+ <div class="section-copy">Decoded with Wan2.1's original VAE.</div>
486
+ """
487
+ )
488
+ wan_video_output = gr.Video(label="Wan VAE Decode", height=250)
489
+
490
+ with gr.Column(elem_classes="output-card"):
491
+ gr.HTML(
492
+ """
493
+ <div class="section-title">RefDecoder Result</div>
494
+ <div class="section-copy">Decoded with the custom RefDecoder checkpoint.</div>
495
+ """
496
+ )
497
+ ref_video_output = gr.Video(label="RefDecoder Decode", height=250)
 
 
 
 
 
 
498
 
499
  run_button.click(
500
  fn=generate_and_decode,
501
  inputs=[image_input, prompt_input, seed_input],
502
+ outputs=[wan_video_output, ref_video_output],
503
  )
504
 
505