BoxOfColors committed on
Commit
dc0df75
·
1 Parent(s): 5031a07

fix: load only AudioLDM2 VAE+vocoder subcomponents instead of full pipeline to prevent GPU OOM on long videos

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -313,7 +313,8 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
313
  from TARO.onset_util import VideoOnsetNet, extract_onset
314
  from TARO.models import MMDiT
315
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
316
- from diffusers import AudioLDM2Pipeline
 
317
 
318
  # -- Load CAVP encoder (uses checkpoint from our HF repo) --
319
  extract_cavp = Extract_CAVP_Features(
@@ -343,11 +344,10 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
343
  model.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
344
  model.eval().to(weight_dtype)
345
 
346
- # -- Load AudioLDM2 VAE + vocoder (decoder pipeline only) --
347
- # TARO uses AudioLDM2's VAE and vocoder for decoding; no encoder needed at inference
348
- audioldm2 = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
349
- vae = audioldm2.vae.to(device).eval()
350
- vocoder = audioldm2.vocoder.to(device)
351
  latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
352
 
353
  # -- Prepare silent video (shared across all samples) --
@@ -873,7 +873,8 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
873
  from TARO.onset_util import VideoOnsetNet, extract_onset
874
  from TARO.models import MMDiT
875
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
876
- from diffusers import AudioLDM2Pipeline
 
877
 
878
  silent_video = meta["silent_video"]
879
  tmp_dir = tempfile.mkdtemp()
@@ -891,9 +892,8 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
891
  model_net = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
892
  model_net.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
893
  model_net.eval().to(weight_dtype)
894
- audioldm2 = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
895
- vae = audioldm2.vae.to(device).eval()
896
- vocoder = audioldm2.vocoder.to(device)
897
  latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
898
 
899
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
 
313
  from TARO.onset_util import VideoOnsetNet, extract_onset
314
  from TARO.models import MMDiT
315
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
316
+ from diffusers import AutoencoderKL
317
+ from transformers import SpeechT5HifiGan
318
 
319
  # -- Load CAVP encoder (uses checkpoint from our HF repo) --
320
  extract_cavp = Extract_CAVP_Features(
 
344
  model.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
345
  model.eval().to(weight_dtype)
346
 
347
+ # -- Load AudioLDM2 VAE + vocoder only (saves ~3-4 GB vs loading the full pipeline) --
348
+ # TARO only needs VAE and vocoder for decoding; the text encoder and UNet are never used.
349
+ vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
350
+ vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)
 
351
  latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
352
 
353
  # -- Prepare silent video (shared across all samples) --
 
873
  from TARO.onset_util import VideoOnsetNet, extract_onset
874
  from TARO.models import MMDiT
875
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
876
+ from diffusers import AutoencoderKL
877
+ from transformers import SpeechT5HifiGan
878
 
879
  silent_video = meta["silent_video"]
880
  tmp_dir = tempfile.mkdtemp()
 
892
  model_net = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
893
  model_net.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
894
  model_net.eval().to(weight_dtype)
895
+ vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
896
+ vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)
 
897
  latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
898
 
899
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)