Spaces:
Configuration error
Configuration error
Commit ·
8db3b31
1
Parent(s): 05c70dd
add memory print
Browse files
app.py
CHANGED
|
@@ -56,7 +56,25 @@ PIPE_DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
|
|
| 56 |
pipeline_wan_i2v.ftfy = ftfy
|
| 57 |
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
def load_generation_pipe():
|
|
|
|
| 60 |
image_encoder = CLIPVisionModel.from_pretrained(
|
| 61 |
MODEL_ID,
|
| 62 |
subfolder="image_encoder",
|
|
@@ -74,10 +92,12 @@ def load_generation_pipe():
|
|
| 74 |
torch_dtype=PIPE_DTYPE,
|
| 75 |
)
|
| 76 |
pipe = pipe.to(DEVICE)
|
|
|
|
| 77 |
return pipe
|
| 78 |
|
| 79 |
|
| 80 |
def load_wan_vae():
|
|
|
|
| 81 |
vae = DiffusersWanVAE.from_pretrained(
|
| 82 |
MODEL_ID,
|
| 83 |
subfolder="vae",
|
|
@@ -85,10 +105,12 @@ def load_wan_vae():
|
|
| 85 |
)
|
| 86 |
vae = vae.to(DEVICE)
|
| 87 |
vae.eval()
|
|
|
|
| 88 |
return vae
|
| 89 |
|
| 90 |
|
| 91 |
def load_refdecoder_module():
|
|
|
|
| 92 |
vae = AutoencoderKLWan(
|
| 93 |
dropout_p=0.0,
|
| 94 |
use_reference=True,
|
|
@@ -122,6 +144,7 @@ def load_refdecoder_module():
|
|
| 122 |
|
| 123 |
vae = vae.to(DEVICE).eval()
|
| 124 |
transformer = transformer.to(DEVICE).eval()
|
|
|
|
| 125 |
return vae, transformer
|
| 126 |
|
| 127 |
|
|
@@ -207,6 +230,7 @@ def generate_and_decode(image, prompt, seed, progress=gr.Progress(track_tqdm=Fal
|
|
| 207 |
prompt = prompt.strip() if prompt else ""
|
| 208 |
seed = int(seed) if seed is not None else random.randint(0, 2**32 - 1)
|
| 209 |
run_dir = Path(tempfile.mkdtemp(prefix="refdecoder_demo_"))
|
|
|
|
| 210 |
|
| 211 |
progress(0.05, desc="Loading Wan I2V pipeline")
|
| 212 |
pipe = load_generation_pipe()
|
|
@@ -231,11 +255,13 @@ def generate_and_decode(image, prompt, seed, progress=gr.Progress(track_tqdm=Fal
|
|
| 231 |
output_type="latent",
|
| 232 |
)
|
| 233 |
latents = normalize_latent_shape(output.frames).detach().cpu()
|
|
|
|
| 234 |
del output
|
| 235 |
del pipe
|
| 236 |
if torch.cuda.is_available():
|
| 237 |
torch.cuda.empty_cache()
|
| 238 |
gc.collect()
|
|
|
|
| 239 |
|
| 240 |
latent_path = run_dir / "wan_latents.pt"
|
| 241 |
torch.save(
|
|
@@ -253,16 +279,19 @@ def generate_and_decode(image, prompt, seed, progress=gr.Progress(track_tqdm=Fal
|
|
| 253 |
wan_vae = load_wan_vae()
|
| 254 |
wan_video = decode_with_wan_vae(latents, wan_vae)
|
| 255 |
wan_video_path = save_video_tensor(wan_video, run_dir / "wan_vae.mp4")
|
|
|
|
| 256 |
del wan_video
|
| 257 |
del wan_vae
|
| 258 |
if torch.cuda.is_available():
|
| 259 |
torch.cuda.empty_cache()
|
| 260 |
gc.collect()
|
|
|
|
| 261 |
|
| 262 |
progress(0.82, desc="Decoding with RefDecoder")
|
| 263 |
ref_vae, ref_transformer = load_refdecoder_module()
|
| 264 |
ref_video = decode_with_refdecoder(latents, reference_frame, ref_vae, ref_transformer)
|
| 265 |
ref_video_path = save_video_tensor(ref_video, run_dir / "refdecoder.mp4")
|
|
|
|
| 266 |
del ref_video
|
| 267 |
del ref_vae
|
| 268 |
del ref_transformer
|
|
@@ -270,16 +299,10 @@ def generate_and_decode(image, prompt, seed, progress=gr.Progress(track_tqdm=Fal
|
|
| 270 |
if torch.cuda.is_available():
|
| 271 |
torch.cuda.empty_cache()
|
| 272 |
gc.collect()
|
|
|
|
| 273 |
|
| 274 |
-
status = (
|
| 275 |
-
f"Seed: {seed}\n"
|
| 276 |
-
f"Prompt: {prompt if prompt else '(empty)'}\n"
|
| 277 |
-
f"Resolution: {width}x{height}\n"
|
| 278 |
-
f"Frames: {NUM_FRAMES}\n"
|
| 279 |
-
f"Latents: {tuple(latents.shape)}"
|
| 280 |
-
)
|
| 281 |
progress(1.0, desc="Done")
|
| 282 |
-
return wan_video_path, ref_video_path
|
| 283 |
|
| 284 |
|
| 285 |
CUSTOM_CSS = """
|
|
@@ -292,7 +315,7 @@ CUSTOM_CSS = """
|
|
| 292 |
--accent: #1f6a52;
|
| 293 |
--accent-2: #c96f42;
|
| 294 |
--text-main: #201a14;
|
| 295 |
-
--text-soft: #
|
| 296 |
--copy-font: "Fraunces", "Iowan Old Style", "Palatino Linotype", serif;
|
| 297 |
}
|
| 298 |
|
|
@@ -392,6 +415,10 @@ CUSTOM_CSS = """
|
|
| 392 |
#generate-btn:hover {
|
| 393 |
filter: brightness(1.04);
|
| 394 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
"""
|
| 396 |
|
| 397 |
|
|
@@ -422,11 +449,11 @@ with gr.Blocks(title="RefDecoder I2V Demo", theme=gr.themes.Soft(), css=CUSTOM_C
|
|
| 422 |
image_input = gr.Image(
|
| 423 |
label="Reference Image",
|
| 424 |
type="pil",
|
| 425 |
-
height=
|
| 426 |
)
|
| 427 |
prompt_input = gr.Textbox(
|
| 428 |
label="Motion Prompt",
|
| 429 |
-
lines=
|
| 430 |
placeholder="A woman turns toward the camera as her hair moves in the wind...",
|
| 431 |
)
|
| 432 |
seed_input = gr.Number(
|
|
@@ -441,44 +468,38 @@ with gr.Blocks(title="RefDecoder I2V Demo", theme=gr.themes.Soft(), css=CUSTOM_C
|
|
| 441 |
elem_id="generate-btn",
|
| 442 |
)
|
| 443 |
|
| 444 |
-
with gr.Column(scale=6):
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
"""
|
| 472 |
-
<div class="section-title">RefDecoder Result</div>
|
| 473 |
-
<div class="section-copy">Decoded with the custom RefDecoder checkpoint.</div>
|
| 474 |
-
"""
|
| 475 |
-
)
|
| 476 |
-
ref_video_output = gr.Video(label="RefDecoder Decode", height=260)
|
| 477 |
|
| 478 |
run_button.click(
|
| 479 |
fn=generate_and_decode,
|
| 480 |
inputs=[image_input, prompt_input, seed_input],
|
| 481 |
-
outputs=[wan_video_output, ref_video_output
|
| 482 |
)
|
| 483 |
|
| 484 |
|
|
|
|
| 56 |
pipeline_wan_i2v.ftfy = ftfy
|
| 57 |
|
| 58 |
|
| 59 |
+
def log_cuda_mem(tag):
|
| 60 |
+
if not torch.cuda.is_available():
|
| 61 |
+
print(f"[mem] {tag}: CUDA not available")
|
| 62 |
+
return
|
| 63 |
+
|
| 64 |
+
free_bytes, total_bytes = torch.cuda.mem_get_info()
|
| 65 |
+
allocated_bytes = torch.cuda.memory_allocated()
|
| 66 |
+
reserved_bytes = torch.cuda.memory_reserved()
|
| 67 |
+
print(
|
| 68 |
+
f"[mem] {tag}: "
|
| 69 |
+
f"free={free_bytes / 1024**3:.2f} GB, "
|
| 70 |
+
f"total={total_bytes / 1024**3:.2f} GB, "
|
| 71 |
+
f"allocated={allocated_bytes / 1024**3:.2f} GB, "
|
| 72 |
+
f"reserved={reserved_bytes / 1024**3:.2f} GB"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
def load_generation_pipe():
|
| 77 |
+
log_cuda_mem("before load_generation_pipe")
|
| 78 |
image_encoder = CLIPVisionModel.from_pretrained(
|
| 79 |
MODEL_ID,
|
| 80 |
subfolder="image_encoder",
|
|
|
|
| 92 |
torch_dtype=PIPE_DTYPE,
|
| 93 |
)
|
| 94 |
pipe = pipe.to(DEVICE)
|
| 95 |
+
log_cuda_mem("after load_generation_pipe")
|
| 96 |
return pipe
|
| 97 |
|
| 98 |
|
| 99 |
def load_wan_vae():
|
| 100 |
+
log_cuda_mem("before load_wan_vae")
|
| 101 |
vae = DiffusersWanVAE.from_pretrained(
|
| 102 |
MODEL_ID,
|
| 103 |
subfolder="vae",
|
|
|
|
| 105 |
)
|
| 106 |
vae = vae.to(DEVICE)
|
| 107 |
vae.eval()
|
| 108 |
+
log_cuda_mem("after load_wan_vae")
|
| 109 |
return vae
|
| 110 |
|
| 111 |
|
| 112 |
def load_refdecoder_module():
|
| 113 |
+
log_cuda_mem("before load_refdecoder_module")
|
| 114 |
vae = AutoencoderKLWan(
|
| 115 |
dropout_p=0.0,
|
| 116 |
use_reference=True,
|
|
|
|
| 144 |
|
| 145 |
vae = vae.to(DEVICE).eval()
|
| 146 |
transformer = transformer.to(DEVICE).eval()
|
| 147 |
+
log_cuda_mem("after load_refdecoder_module")
|
| 148 |
return vae, transformer
|
| 149 |
|
| 150 |
|
|
|
|
| 230 |
prompt = prompt.strip() if prompt else ""
|
| 231 |
seed = int(seed) if seed is not None else random.randint(0, 2**32 - 1)
|
| 232 |
run_dir = Path(tempfile.mkdtemp(prefix="refdecoder_demo_"))
|
| 233 |
+
log_cuda_mem("start generate_and_decode")
|
| 234 |
|
| 235 |
progress(0.05, desc="Loading Wan I2V pipeline")
|
| 236 |
pipe = load_generation_pipe()
|
|
|
|
| 255 |
output_type="latent",
|
| 256 |
)
|
| 257 |
latents = normalize_latent_shape(output.frames).detach().cpu()
|
| 258 |
+
log_cuda_mem("after latent generation")
|
| 259 |
del output
|
| 260 |
del pipe
|
| 261 |
if torch.cuda.is_available():
|
| 262 |
torch.cuda.empty_cache()
|
| 263 |
gc.collect()
|
| 264 |
+
log_cuda_mem("after freeing generation pipe")
|
| 265 |
|
| 266 |
latent_path = run_dir / "wan_latents.pt"
|
| 267 |
torch.save(
|
|
|
|
| 279 |
wan_vae = load_wan_vae()
|
| 280 |
wan_video = decode_with_wan_vae(latents, wan_vae)
|
| 281 |
wan_video_path = save_video_tensor(wan_video, run_dir / "wan_vae.mp4")
|
| 282 |
+
log_cuda_mem("after wan decode")
|
| 283 |
del wan_video
|
| 284 |
del wan_vae
|
| 285 |
if torch.cuda.is_available():
|
| 286 |
torch.cuda.empty_cache()
|
| 287 |
gc.collect()
|
| 288 |
+
log_cuda_mem("after freeing wan vae")
|
| 289 |
|
| 290 |
progress(0.82, desc="Decoding with RefDecoder")
|
| 291 |
ref_vae, ref_transformer = load_refdecoder_module()
|
| 292 |
ref_video = decode_with_refdecoder(latents, reference_frame, ref_vae, ref_transformer)
|
| 293 |
ref_video_path = save_video_tensor(ref_video, run_dir / "refdecoder.mp4")
|
| 294 |
+
log_cuda_mem("after refdecoder decode")
|
| 295 |
del ref_video
|
| 296 |
del ref_vae
|
| 297 |
del ref_transformer
|
|
|
|
| 299 |
if torch.cuda.is_available():
|
| 300 |
torch.cuda.empty_cache()
|
| 301 |
gc.collect()
|
| 302 |
+
log_cuda_mem("after freeing refdecoder")
|
| 303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
progress(1.0, desc="Done")
|
| 305 |
+
return wan_video_path, ref_video_path
|
| 306 |
|
| 307 |
|
| 308 |
CUSTOM_CSS = """
|
|
|
|
| 315 |
--accent: #1f6a52;
|
| 316 |
--accent-2: #c96f42;
|
| 317 |
--text-main: #201a14;
|
| 318 |
+
--text-soft: #201a14;
|
| 319 |
--copy-font: "Fraunces", "Iowan Old Style", "Palatino Linotype", serif;
|
| 320 |
}
|
| 321 |
|
|
|
|
| 415 |
#generate-btn:hover {
|
| 416 |
filter: brightness(1.04);
|
| 417 |
}
|
| 418 |
+
|
| 419 |
+
.output-grid {
|
| 420 |
+
gap: 14px;
|
| 421 |
+
}
|
| 422 |
"""
|
| 423 |
|
| 424 |
|
|
|
|
| 449 |
image_input = gr.Image(
|
| 450 |
label="Reference Image",
|
| 451 |
type="pil",
|
| 452 |
+
height=320,
|
| 453 |
)
|
| 454 |
prompt_input = gr.Textbox(
|
| 455 |
label="Motion Prompt",
|
| 456 |
+
lines=4,
|
| 457 |
placeholder="A woman turns toward the camera as her hair moves in the wind...",
|
| 458 |
)
|
| 459 |
seed_input = gr.Number(
|
|
|
|
| 468 |
elem_id="generate-btn",
|
| 469 |
)
|
| 470 |
|
| 471 |
+
with gr.Column(scale=6, elem_classes="panel-card"):
|
| 472 |
+
gr.HTML(
|
| 473 |
+
"""
|
| 474 |
+
<div class="section-title">Decoder Comparison</div>
|
| 475 |
+
<div class="section-copy">
|
| 476 |
+
Same Wan latent video, rendered with two different decoders.
|
| 477 |
+
</div>
|
| 478 |
+
"""
|
| 479 |
+
)
|
| 480 |
+
with gr.Row(equal_height=True, elem_classes="output-grid"):
|
| 481 |
+
with gr.Column(elem_classes="output-card"):
|
| 482 |
+
gr.HTML(
|
| 483 |
+
"""
|
| 484 |
+
<div class="section-title">Wan Baseline</div>
|
| 485 |
+
<div class="section-copy">Decoded with Wan2.1's original VAE.</div>
|
| 486 |
+
"""
|
| 487 |
+
)
|
| 488 |
+
wan_video_output = gr.Video(label="Wan VAE Decode", height=250)
|
| 489 |
+
|
| 490 |
+
with gr.Column(elem_classes="output-card"):
|
| 491 |
+
gr.HTML(
|
| 492 |
+
"""
|
| 493 |
+
<div class="section-title">RefDecoder Result</div>
|
| 494 |
+
<div class="section-copy">Decoded with the custom RefDecoder checkpoint.</div>
|
| 495 |
+
"""
|
| 496 |
+
)
|
| 497 |
+
ref_video_output = gr.Video(label="RefDecoder Decode", height=250)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
|
| 499 |
run_button.click(
|
| 500 |
fn=generate_and_decode,
|
| 501 |
inputs=[image_input, prompt_input, seed_input],
|
| 502 |
+
outputs=[wan_video_output, ref_video_output],
|
| 503 |
)
|
| 504 |
|
| 505 |
|