prismaudio-project committed on
Commit
72b3eff
·
1 Parent(s): be7e0ef
Files changed (1) hide show
  1. app.py +34 -29
app.py CHANGED
@@ -345,7 +345,11 @@ def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> to
345
 
346
  # ==================== Full Inference Pipeline ====================
347
 
348
- def generate_audio(video_file, caption: str):
 
 
 
 
349
  start_time =time.time()
350
 
351
  """
@@ -356,11 +360,11 @@ def generate_audio(video_file, caption: str):
356
  """
357
  # ---- Basic validation ----
358
  if video_file is None:
359
- yield "❌ Please upload a video file first.", None
360
- return
361
  if not caption or caption.strip() == "":
362
- yield "❌ Please enter a caption / prompt.", None
363
- return
364
 
365
  caption = caption.strip()
366
  logs = []
@@ -375,9 +379,7 @@ def generate_audio(video_file, caption: str):
375
 
376
  try:
377
  # ---- Step 1: Convert / copy to mp4 ----
378
- status = log_step("πŸ“Ή Step 1: Preparing video...")
379
-
380
- yield status, None
381
 
382
  src_ext = os.path.splitext(video_file)[1].lower()
383
  mp4_path = os.path.join(work_dir, "input.mp4")
@@ -386,30 +388,28 @@ def generate_audio(video_file, caption: str):
386
  log_step(" Converting to mp4...")
387
  ok, err = convert_to_mp4(video_file, mp4_path)
388
  if not ok:
389
- yield log_step(f"❌ Video conversion failed:\n{err}"), None
390
- return
391
  else:
392
  shutil.copy(video_file, mp4_path)
393
  log_step(" Video ready.")
394
 
395
  # ---- Step 2: Validate duration ----
396
- status = log_step("πŸ“Ή Step 2: Checking video duration...")
397
- yield status, None
398
 
399
  duration = get_video_duration(mp4_path)
400
  log_step(f" Duration: {duration:.2f}s")
401
 
402
  # ---- Step 3: Extract video frames ----
403
- status = log_step("🎞️ Step 3: Extracting video frames (clip & sync)...")
404
- yield status, None
405
 
406
  clip_chunk, sync_chunk, duration = extract_video_frames(mp4_path)
407
  log_step(f" clip_chunk : {tuple(clip_chunk.shape)}")
408
  log_step(f" sync_chunk : {tuple(sync_chunk.shape)}")
409
 
410
  # ---- Step 4: Extract model features ----
411
- status = log_step("🧠 Step 4: Extracting text / video / sync features...")
412
- yield status, None
413
 
414
  info = extract_features(clip_chunk, sync_chunk, caption)
415
  log_step(f" text_features : {tuple(info['text_features'].shape)}")
@@ -419,22 +419,22 @@ def generate_audio(video_file, caption: str):
419
  log_step(f" sync_features : {tuple(info['sync_features'].shape)}")
420
 
421
  # ---- Step 5: Build inference batch ----
422
- status = log_step("πŸ“¦ Step 5: Building inference batch...")
423
- yield status, None
424
 
425
  audio_latent, meta = build_meta(info, duration, caption)
426
  log_step(f" audio_latent : {tuple(audio_latent.shape)}")
427
 
428
  # ---- Step 6: Diffusion sampling ----
429
- status = log_step("🎡 Step 6: Running diffusion sampling...")
430
- yield status, None
431
 
432
  generated_audio = run_diffusion(audio_latent, meta, duration)
433
  log_step(f" Generated audio shape : {tuple(generated_audio.shape)}")
434
 
435
  # ---- Step 7: Save generated audio (temp) ----
436
- status = log_step("πŸ’Ύ Step 7: Saving generated audio...")
437
- yield status, None
438
 
439
  audio_path = os.path.join(work_dir, "generated_audio.wav")
440
  torchaudio.save(
@@ -445,22 +445,22 @@ def generate_audio(video_file, caption: str):
445
  log_step(f" Audio saved: {audio_path}")
446
 
447
  # ---- Step 8: Mux audio into original video ----
448
- status = log_step("🎬 Step 8: Merging audio into video...")
449
- yield status, None
450
 
451
  combined_path = os.path.join(work_dir, "output_with_audio.mp4")
452
  ok, err = combine_audio_video(mp4_path, audio_path, combined_path)
453
  if not ok:
454
- yield log_step(f"❌ Failed to combine audio and video:\n{err}"), None
455
- return
456
 
457
  log_step("βœ… Done! Audio and video merged successfully.")
458
- yield "\n".join(logs), combined_path
459
 
460
  except Exception as e:
461
  log_step(f"❌ Unexpected error: {str(e)}")
462
  log.exception(e)
463
- yield "\n".join(logs), None
464
 
465
  end_time =time.time()
466
  print("cost: ",end_time-start_time)
@@ -468,6 +468,11 @@ def generate_audio(video_file, caption: str):
468
  # Note: work_dir is NOT deleted here so Gradio can serve the output file.
469
  # Gradio manages its own GRADIO_TEMP_DIR cleanup on restart.
470
 
 
 
 
 
 
471
 
472
  # ==================== Gradio UI ====================
473
 
@@ -622,7 +627,7 @@ if __name__ == "__main__":
622
  log.info("βœ… All model files found.")
623
 
624
  # ⭐ Load all models once at startup
625
- load_all_models()
626
 
627
  demo = build_ui()
628
  demo.queue(max_size=3)
 
345
 
346
  # ==================== Full Inference Pipeline ====================
347
 
348
+ @spaces.GPU
349
+ def generate_audio_core(video_file, caption):
350
+ if _MODELS["diffusion"] is None:
351
+ load_all_models()
352
+
353
  start_time =time.time()
354
 
355
  """
 
360
  """
361
  # ---- Basic validation ----
362
  if video_file is None:
363
+ return "❌ Please upload a video file first.", None
364
+
365
  if not caption or caption.strip() == "":
366
+ caption=""
367
+
368
 
369
  caption = caption.strip()
370
  logs = []
 
379
 
380
  try:
381
  # ---- Step 1: Convert / copy to mp4 ----
382
+ #status = log_step("πŸ“Ή Step 1: Preparing video...")
 
 
383
 
384
  src_ext = os.path.splitext(video_file)[1].lower()
385
  mp4_path = os.path.join(work_dir, "input.mp4")
 
388
  log_step(" Converting to mp4...")
389
  ok, err = convert_to_mp4(video_file, mp4_path)
390
  if not ok:
391
+ return log_step(f"❌ Video conversion failed:\n{err}"), None
 
392
  else:
393
  shutil.copy(video_file, mp4_path)
394
  log_step(" Video ready.")
395
 
396
  # ---- Step 2: Validate duration ----
397
+ #status = log_step("πŸ“Ή Step 2: Checking video duration...")
398
+
399
 
400
  duration = get_video_duration(mp4_path)
401
  log_step(f" Duration: {duration:.2f}s")
402
 
403
  # ---- Step 3: Extract video frames ----
404
+ #status = log_step("🎞️ Step 3: Extracting video frames (clip & sync)...")
 
405
 
406
  clip_chunk, sync_chunk, duration = extract_video_frames(mp4_path)
407
  log_step(f" clip_chunk : {tuple(clip_chunk.shape)}")
408
  log_step(f" sync_chunk : {tuple(sync_chunk.shape)}")
409
 
410
  # ---- Step 4: Extract model features ----
411
+ #status = log_step("🧠 Step 4: Extracting text / video / sync features...")
412
+ #yield status, None
413
 
414
  info = extract_features(clip_chunk, sync_chunk, caption)
415
  log_step(f" text_features : {tuple(info['text_features'].shape)}")
 
419
  log_step(f" sync_features : {tuple(info['sync_features'].shape)}")
420
 
421
  # ---- Step 5: Build inference batch ----
422
+ #status = log_step("πŸ“¦ Step 5: Building inference batch...")
423
+ #yield status, None
424
 
425
  audio_latent, meta = build_meta(info, duration, caption)
426
  log_step(f" audio_latent : {tuple(audio_latent.shape)}")
427
 
428
  # ---- Step 6: Diffusion sampling ----
429
+ #status = log_step("🎡 Step 6: Running diffusion sampling...")
430
+ #yield status, None
431
 
432
  generated_audio = run_diffusion(audio_latent, meta, duration)
433
  log_step(f" Generated audio shape : {tuple(generated_audio.shape)}")
434
 
435
  # ---- Step 7: Save generated audio (temp) ----
436
+ #status = log_step("πŸ’Ύ Step 7: Saving generated audio...")
437
+ #yield status, None
438
 
439
  audio_path = os.path.join(work_dir, "generated_audio.wav")
440
  torchaudio.save(
 
445
  log_step(f" Audio saved: {audio_path}")
446
 
447
  # ---- Step 8: Mux audio into original video ----
448
+ #status = log_step("🎬 Step 8: Merging audio into video...")
449
+ #yield status, None
450
 
451
  combined_path = os.path.join(work_dir, "output_with_audio.mp4")
452
  ok, err = combine_audio_video(mp4_path, audio_path, combined_path)
453
  if not ok:
454
+ return log_step(f"❌ Failed to combine audio and video:\n{err}"), None
455
+
456
 
457
  log_step("βœ… Done! Audio and video merged successfully.")
458
+ return "\n".join(logs), combined_path
459
 
460
  except Exception as e:
461
  log_step(f"❌ Unexpected error: {str(e)}")
462
  log.exception(e)
463
+ return "\n".join(logs), None
464
 
465
  end_time =time.time()
466
  print("cost: ",end_time-start_time)
 
468
  # Note: work_dir is NOT deleted here so Gradio can serve the output file.
469
  # Gradio manages its own GRADIO_TEMP_DIR cleanup on restart.
470
 
471
def generate_audio(video_file, caption):
    """Gradio-facing generator wrapper around ``generate_audio_core``.

    Immediately yields an interim status so the UI updates while the
    GPU-decorated core function is (possibly) queued, then yields the
    final (log_text, video_path) result.

    Args:
        video_file: Path of the uploaded video (or None).
        caption: Text prompt describing the desired audio.

    Yields:
        tuple[str, str | None]: status/log text and the output video path
        (None until the core call finishes, or on failure).
    """
    # Surface a status message first so the user sees progress right away.
    yield "⏳ Waiting for GPU...", None
    logs, output_video = generate_audio_core(video_file, caption)
    yield logs, output_video
476
 
477
  # ==================== Gradio UI ====================
478
 
 
627
  log.info("βœ… All model files found.")
628
 
629
  # ⭐ Load all models once at startup
630
+ #load_all_models()
631
 
632
  demo = build_ui()
633
  demo.queue(max_size=3)