Commit dc0df75 · Parent(s): 5031a07
fix: load only AudioLDM2 VAE+vocoder subcomponents instead of full pipeline to prevent GPU OOM on long videos
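Loading the full AudioLDM2 pipeline moves every sub-model to the GPU even though, as the new comments in the diff note, TARO only uses the VAE and the vocoder for decoding. A minimal sketch of the two loading strategies: the pre-commit lines are truncated on this page, so the "before" variant is an assumption based on the standard diffusers API, while the "after" variant mirrors the lines added in this commit.

import torch
from diffusers import AutoencoderKL
from transformers import SpeechT5HifiGan

device = "cuda" if torch.cuda.is_available() else "cpu"

# Before (assumed; the original import line is truncated in the diff): the full
# pipeline drags the UNet, text encoders, and every other sub-model onto the GPU.
#   from diffusers import AudioLDM2Pipeline
#   audioldm2 = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2").to(device)
#   vocoder = audioldm2.vocoder
#
# After (as added in this commit): load only the two sub-models needed for decoding.
vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)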
app.py CHANGED

@@ -313,7 +313,8 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
 from TARO.onset_util import VideoOnsetNet, extract_onset
 from TARO.models import MMDiT
 from TARO.samplers import euler_sampler, euler_maruyama_sampler
-from diffusers import
+from diffusers import AutoencoderKL
+from transformers import SpeechT5HifiGan
 
 # -- Load CAVP encoder (uses checkpoint from our HF repo) --
 extract_cavp = Extract_CAVP_Features(
@@ -343,11 +344,10 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
 model.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
 model.eval().to(weight_dtype)
 
-# -- Load AudioLDM2 VAE + vocoder (
-# TARO
-
-
-vocoder = audioldm2.vocoder.to(device)
+# -- Load AudioLDM2 VAE + vocoder only (saves ~3-4 GB vs loading the full pipeline) --
+# TARO only needs VAE and vocoder for decoding; the text encoder and UNet are never used.
+vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
+vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)
 latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
 
 # -- Prepare silent video (shared across all samples) --
@@ -873,7 +873,8 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
 from TARO.onset_util import VideoOnsetNet, extract_onset
 from TARO.models import MMDiT
 from TARO.samplers import euler_sampler, euler_maruyama_sampler
-from diffusers import
+from diffusers import AutoencoderKL
+from transformers import SpeechT5HifiGan
 
 silent_video = meta["silent_video"]
 tmp_dir = tempfile.mkdtemp()
@@ -891,9 +892,8 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
 model_net = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
 model_net.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
 model_net.eval().to(weight_dtype)
-
-
-vocoder = audioldm2.vocoder.to(device)
+vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
+vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)
 latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
 
 cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
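The decode step that consumes vae, vocoder, and latents_scale sits outside the hunks above. As a rough sketch of how the standalone sub-models are typically chained for decoding (it follows the standard AudioLDM2 decode path; the function name, tensor shapes, and the latents argument are illustrative, not taken from app.py):

import torch

@torch.no_grad()
def latents_to_waveform(latents, vae, vocoder, latents_scale):
    # Undo the 0.18215 latent scaling (the [1, 8, 1, 1] latents_scale tensor
    # built above) before handing the latents to the VAE decoder.
    mel = vae.decode(latents / latents_scale).sample  # (B, 1, frames, n_mels)
    if mel.dim() == 4:
        mel = mel.squeeze(1)                          # (B, frames, n_mels)
    return vocoder(mel)                               # (B, num_samples) waveform

Since only these two sub-models stay resident, the ~3-4 GB noted in the diff becomes headroom for the larger latents of long videos; the saving can be sanity-checked with torch.cuda.memory_allocated() before and after loading.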