prismaudio-project committed on
Commit
303133d
·
1 Parent(s): c367d8b
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -59,7 +59,8 @@ CKPT_PATH = "ckpts/prismaudio.ckpt"
59
  VAE_CKPT_PATH = "ckpts/vae.ckpt"
60
  VAE_CONFIG_PATH = "PrismAudio/configs/model_configs/stable_audio_2_0_vae.json"
61
  SYNCHFORMER_CKPT_PATH = "ckpts/synchformer_state_dict.pth"
62
- DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 
63
 
64
  # ==================== Global Model Registry ====================
65
  _MODELS = {
@@ -96,7 +97,7 @@ def load_all_models():
96
  enable_conditions=True,
97
  synchformer_ckpt=SYNCHFORMER_CKPT_PATH,
98
  )
99
- feature_extractor = feature_extractor.eval().to(DEVICE)
100
  _MODELS["feature_extractor"] = feature_extractor
101
  log.info("✅ FeaturesUtils loaded")
102
 
@@ -114,7 +115,7 @@ def load_all_models():
114
  vae_state = load_ckpt_state_dict(VAE_CKPT_PATH, prefix='autoencoder.')
115
  diffusion.pretransform.load_state_dict(vae_state)
116
 
117
- diffusion = diffusion.eval().to(DEVICE)
118
  _MODELS["diffusion"] = diffusion
119
  log.info("✅ Diffusion model loaded")
120
 
@@ -353,7 +354,10 @@ def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> to
353
  @spaces.GPU
354
  def generate_audio_core(video_file, caption):
355
 
356
-
 
 
 
357
  start_time =time.time()
358
 
359
  """
@@ -379,11 +383,9 @@ def generate_audio_core(video_file, caption):
379
  return "\n".join(logs)
380
 
381
  # ---- Working directory (auto-cleaned on exit) ----
382
- work_dir = tempfile.mkdtemp(dir=os.environ["GRADIO_TEMP_DIR"], prefix="PrismAudio_")
383
 
384
  try:
385
- if _MODELS["diffusion"] is None:
386
- load_all_models()
387
  # ---- Step 1: Convert / copy to mp4 ----
388
  status = log_step("📹 Step 1: Preparing video...")
389
 
@@ -619,7 +621,7 @@ if __name__ == "__main__":
619
  log.info("✅ All model files found.")
620
 
621
  # ⭐ Load all models once at startup
622
- #load_all_models()
623
 
624
  demo = build_ui()
625
  demo.queue(max_size=3)
 
59
  VAE_CKPT_PATH = "ckpts/vae.ckpt"
60
  VAE_CONFIG_PATH = "PrismAudio/configs/model_configs/stable_audio_2_0_vae.json"
61
  SYNCHFORMER_CKPT_PATH = "ckpts/synchformer_state_dict.pth"
62
+ DEVICE = 'cpu' # 启动时用CPU
63
+
64
 
65
  # ==================== Global Model Registry ====================
66
  _MODELS = {
 
97
  enable_conditions=True,
98
  synchformer_ckpt=SYNCHFORMER_CKPT_PATH,
99
  )
100
+ feature_extractor = feature_extractor.eval()
101
  _MODELS["feature_extractor"] = feature_extractor
102
  log.info("✅ FeaturesUtils loaded")
103
 
 
115
  vae_state = load_ckpt_state_dict(VAE_CKPT_PATH, prefix='autoencoder.')
116
  diffusion.pretransform.load_state_dict(vae_state)
117
 
118
+ diffusion = diffusion.eval()
119
  _MODELS["diffusion"] = diffusion
120
  log.info("✅ Diffusion model loaded")
121
 
 
354
  @spaces.GPU
355
  def generate_audio_core(video_file, caption):
356
 
357
+ global DEVICE
358
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
359
+ _MODELS["feature_extractor"].to(DEVICE)
360
+ _MODELS["diffusion"].to(DEVICE)
361
  start_time =time.time()
362
 
363
  """
 
383
  return "\n".join(logs)
384
 
385
  # ---- Working directory (auto-cleaned on exit) ----
386
+ work_dir = tempfile.mkdtemp(prefix="PrismAudio_")
387
 
388
  try:
 
 
389
  # ---- Step 1: Convert / copy to mp4 ----
390
  status = log_step("📹 Step 1: Preparing video...")
391
 
 
621
  log.info("✅ All model files found.")
622
 
623
  # ⭐ Load all models once at startup
624
+ load_all_models()
625
 
626
  demo = build_ui()
627
  demo.queue(max_size=3)