import os
import subprocess
import sys

# One-time bootstrap: run setup.sh, then re-exec this interpreter (presumably
# so that anything setup.sh installed is picked up). SETUP_DONE is written to
# os.environ, which the re-exec'd process inherits, so this runs only once.
if os.environ.get("SETUP_DONE") != "1":
    subprocess.run(["bash", "setup.sh"], check=True)
    os.environ["SETUP_DONE"] = "1"
    os.execv(sys.executable, [sys.executable] + sys.argv)

import spaces

# ⭐ Must be set before importing gradio: keep JAX on the CPU.
os.environ["JAX_PLATFORMS"] = "cpu"

import gradio as gr
import logging
import sys
import json
import torch
import torchaudio
import numpy as np
import tempfile
import shutil
import subprocess
from pathlib import Path
import torch.nn.functional as F
import mediapy
from torio.io import StreamingMediaDecoder
from torchvision.transforms import v2
import time
import random

# Seed every RNG in use so repeated runs are as repeatable as possible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

try:
    from moviepy import VideoFileClip  # newer moviepy package layout
except ImportError:
    from moviepy.editor import VideoFileClip  # legacy moviepy package layout

# ==================== Logging ====================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger()

# ==================== Constants ====================
_CLIP_FPS = 4        # frame rate of the clip/video-feature stream
_CLIP_SIZE = 288     # square side length of clip frames after pad_to_square
_SYNC_FPS = 25       # frame rate of the sync-feature stream
_SYNC_SIZE = 224     # square side length of sync frames after sync_transform
SAMPLE_RATE = 44100  # output audio sample rate (Hz)

# ==================== Model Path Configuration ====================
from huggingface_hub import snapshot_download

# Download the released checkpoints into ./ckpts at import time.
snapshot_download(repo_id="FunAudioLLM/PrismAudio", local_dir="./ckpts")

MODEL_CONFIG_PATH = "PrismAudio/configs/model_configs/prismaudio.json"
CKPT_PATH = "ckpts/prismaudio.ckpt"
VAE_CKPT_PATH = "ckpts/vae.ckpt"
VAE_CONFIG_PATH = "PrismAudio/configs/model_configs/stable_audio_2_0_vae.json"
SYNCHFORMER_CKPT_PATH = "ckpts/synchformer_state_dict.pth"
DEVICE = 'cpu'  # use the CPU at startup

# ==================== Global Model Registry ====================
# Populated by load_all_models(); every entry is None until then.
_MODELS = {
    "feature_extractor": None,
    "diffusion": None,
    "model_config": None,
    "sync_transform": None,
}


def load_all_models():
    """Load all models once at application startup and store them in ``_MODELS``.

    Fills ``sync_transform``, ``feature_extractor``, ``model_config`` and
    ``diffusion``. The diffusion checkpoint is loaded with
    ``map_location='cpu'``; both models are put in eval mode.
    """
    log.info("=" * 50)
    log.info("Loading all models...")

    # ---- 1. Sync video transform ----
    _MODELS["sync_transform"] = v2.Compose([
        v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
        v2.CenterCrop(_SYNC_SIZE),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        # Maps [0, 1] to [-1, 1].
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    log.info("✅ sync_transform ready")

    # ---- 2. FeaturesUtils ----
    from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils

    _MODELS["feature_extractor"] = FeaturesUtils(
        vae_ckpt=None,
        vae_config=VAE_CONFIG_PATH,
        enable_conditions=True,
        synchformer_ckpt=SYNCHFORMER_CKPT_PATH,
    ).eval()
    log.info("✅ FeaturesUtils loaded")

    # ---- 3. Diffusion model ----
    from PrismAudio.models import create_model_from_config
    from PrismAudio.models.utils import load_ckpt_state_dict

    with open(MODEL_CONFIG_PATH) as f:
        model_config = json.load(f)
    _MODELS["model_config"] = model_config

    diffusion = create_model_from_config(model_config)
    diffusion.load_state_dict(torch.load(CKPT_PATH, map_location='cpu'))
    # The VAE (pretransform) weights come from a separate checkpoint; keys are
    # selected by the 'autoencoder.' prefix (presumably stripped by the helper).
    vae_state = load_ckpt_state_dict(VAE_CKPT_PATH, prefix='autoencoder.')
    diffusion.pretransform.load_state_dict(vae_state)
    _MODELS["diffusion"] = diffusion.eval()
    log.info("✅ Diffusion model loaded")

    log.info("=" * 50)
    log.info("All models ready. Waiting for inference requests.")


# ==================== Video Utilities ====================
def get_video_duration(video_path: str) -> float:
    """Return the duration of *video_path* in seconds.

    The clip is always closed, even if reading the duration fails.
    """
    video = VideoFileClip(str(video_path))
    try:
        return video.duration
    finally:
        video.close()


def convert_to_mp4(src: str, dst: str) -> tuple[bool, str]:
    """Re-encode any video format to h264/aac mp4 via ffmpeg.

    Returns:
        ``(success, ffmpeg_stderr)``.
    """
    result = subprocess.run(
        [
            "ffmpeg", "-y", "-i", src,
            "-c:v", "libx264", "-preset", "fast",
            "-c:a", "aac", "-strict", "experimental",
            dst,
        ],
        capture_output=True,
        text=True,
    )
    return result.returncode == 0, result.stderr


def combine_audio_video(video_path: str, audio_path: str, output_path: str) -> tuple[bool, str]:
    """Mux *audio_path* into *video_path* via ffmpeg.

    Only the first video stream of *video_path* and the first audio stream of
    *audio_path* are mapped, so any existing audio track is replaced. The
    video is stream-copied and the output is cut to the shorter input.

    Returns:
        ``(success, ffmpeg_stderr)``.
    """
    result = subprocess.run(
        [
            "ffmpeg", "-y",
            "-i", video_path,
            "-i", audio_path,
            "-c:v", "copy",
            "-c:a", "aac", "-strict", "experimental",
            "-map", "0:v:0", "-map", "1:a:0",
            "-shortest",
            output_path,
        ],
        capture_output=True,
        text=True,
    )
    return result.returncode == 0, result.stderr


def pad_to_square(video_tensor: torch.Tensor) -> torch.Tensor:
    """Zero-pad frames to a centred square, then resize bilinearly.

    (L, C, H, W) -> (L, C, _CLIP_SIZE, _CLIP_SIZE)

    Raises:
        ValueError: if *video_tensor* is not 4-D.
    """
    if len(video_tensor.shape) != 4:
        raise ValueError("Input tensor must have shape (L, C, H, W)")

    _, _, h, w = video_tensor.shape
    max_side = max(h, w)
    pad_h = max_side - h
    pad_w = max_side - w
    # F.pad order for the last two dims: (left, right, top, bottom).
    padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
    video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)

    return F.interpolate(
        video_padded,
        size=(_CLIP_SIZE, _CLIP_SIZE),
        mode='bilinear',
        align_corners=False,
    )


def _fit_frame_count(frames: torch.Tensor, expected: int) -> torch.Tensor:
    """Trim *frames* to *expected* along dim 0, repeating the last frame if short."""
    frames = frames[:expected]
    missing = expected - frames.shape[0]
    if missing > 0:
        if frames.shape[0] == 0:
            # Nothing to repeat: padding would silently yield an empty tensor.
            raise RuntimeError("Video stream produced no frames")
        frames = torch.cat([frames, frames[-1:].repeat(missing, 1, 1, 1)], dim=0)
    return frames


def extract_video_frames(video_path: str):
    """
    Decode clip_chunk and sync_chunk from video entirely in memory.

    Returns:
        clip_chunk : (L, H, W, C) float32 [0, 1]
        sync_chunk : (L, C, H, W) float32 normalized
        duration   : float (seconds)

    Raises:
        RuntimeError: if the models are not loaded or a stream yields no frames.
        ValueError: if the video is too short to yield one frame per stream.
    """
    sync_transform = _MODELS["sync_transform"]
    if sync_transform is None:
        raise RuntimeError("Call load_all_models() first.")

    duration_sec = get_video_duration(video_path)
    clip_expected = int(_CLIP_FPS * duration_sec)
    sync_expected = int(_SYNC_FPS * duration_sec)
    if clip_expected < 1 or sync_expected < 1:
        raise ValueError(
            f"Video is too short ({duration_sec:.2f}s) to extract frames."
        )

    # Two resampled views of the same video, each decoded as a single chunk:
    # one at _CLIP_FPS for the clip features, one at _SYNC_FPS for sync features.
    reader = StreamingMediaDecoder(video_path)
    reader.add_basic_video_stream(
        frames_per_chunk=clip_expected,
        frame_rate=_CLIP_FPS,
        format='rgb24',
    )
    reader.add_basic_video_stream(
        frames_per_chunk=sync_expected,
        frame_rate=_SYNC_FPS,
        format='rgb24',
    )
    reader.fill_buffer()
    clip_chunk, sync_chunk = reader.pop_chunks()

    if clip_chunk is None:
        raise RuntimeError("CLIP video stream returned None")
    if sync_chunk is None:
        raise RuntimeError("Sync video stream returned None")

    # ---- clip_chunk: square-pad/resize, channels-last, scale to [0, 1] ----
    clip_chunk = _fit_frame_count(clip_chunk, clip_expected)
    clip_chunk = pad_to_square(clip_chunk).permute(0, 2, 3, 1)
    clip_chunk = mediapy.to_float01(clip_chunk)

    # ---- sync_chunk ----
    sync_chunk = sync_transform(_fit_frame_count(sync_chunk, sync_expected))

    log.info(f"clip_chunk: {clip_chunk.shape}, sync_chunk: {sync_chunk.shape}")
    return clip_chunk, sync_chunk, duration_sec


def extract_features_cpu(clip_chunk, sync_chunk, caption):
    """Compute the VideoPrism video/text features on the CPU.

    ``sync_chunk`` is not used here; it is accepted for signature parity with
    ``extract_features_gpu``.

    Returns:
        dict with CPU tensors ``global_video_features``, ``video_features``
        and ``global_text_features`` (leading batch dim squeezed).
    """
    model = _MODELS["feature_extractor"]
    if model is None:
        raise RuntimeError("Call load_all_models() first.")

    info = {}
    with torch.no_grad():
        clip_input = torch.from_numpy(clip_chunk).unsqueeze(0)
        video_feat, frame_embed, _, text_feat = \
            model.encode_video_and_text_with_videoprism(clip_input, [caption])
        # The encoder outputs are converted via numpy before becoming tensors.
        info['global_video_features'] = torch.tensor(np.array(video_feat)).squeeze(0).cpu()
        info['video_features'] = torch.tensor(np.array(frame_embed)).squeeze(0).cpu()
        info['global_text_features'] = torch.tensor(np.array(text_feat)).squeeze(0).cpu()
    return info
==================== Feature Extraction ==================== @spaces.GPU def extract_features_gpu(clip_chunk, sync_chunk, caption): model = _MODELS["feature_extractor"] info = {} with torch.no_grad(): model.t5.to('cuda') text_features = model.encode_t5_text([caption]) info['text_features'] = text_features[0].cpu() model.t5.to('cpu') model.synchformer.to('cuda') sync_input = sync_chunk.unsqueeze(0).to('cuda') info['sync_features'] = model.encode_video_with_sync(sync_input)[0].cpu() model.synchformer.to('cpu') return info def extract_features(clip_chunk, sync_chunk, caption): info = extract_features_cpu(clip_chunk, sync_chunk, caption) info.update(extract_features_gpu(clip_chunk, sync_chunk, caption)) return info # ==================== Build Meta ==================== def build_meta(info: dict, duration: float, caption: str): latent_length = round(SAMPLE_RATE * duration / 2048) audio_latent = torch.zeros((1, 64, latent_length), dtype=torch.float32) meta = dict(info) meta['id'] = 'demo' meta['relpath'] = 'demo.npz' meta['path'] = 'demo.npz' meta['caption_cot'] = caption meta['video_exist'] = torch.tensor(True) return audio_latent, meta # ==================== Diffusion Sampling ==================== @spaces.GPU def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> torch.Tensor: """Reuses globally loaded diffusion model — no reload per call.""" from PrismAudio.inference.sampling import sample, sample_discrete_euler import time diffusion = _MODELS["diffusion"] model_config = _MODELS["model_config"] device = 'cuda' diffusion.to("cuda") assert diffusion is not None, "Diffusion model not initialized." 
diffusion_objective = model_config["model"]["diffusion"]["diffusion_objective"] latent_length = round(SAMPLE_RATE * duration / 2048) meta_on_device = { k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in meta.items() } metadata = (meta_on_device,) with torch.no_grad(): with torch.amp.autocast('cuda'): conditioning = diffusion.conditioner(metadata, device) video_exist = torch.stack([item['video_exist'] for item in metadata], dim=0) if 'metaclip_features' in conditioning: conditioning['metaclip_features'][~video_exist] = \ diffusion.model.model.empty_clip_feat if 'sync_features' in conditioning: conditioning['sync_features'][~video_exist] = \ diffusion.model.model.empty_sync_feat cond_inputs = diffusion.get_conditioning_inputs(conditioning) noise = torch.randn([1, diffusion.io_channels, latent_length]).to(device) with torch.amp.autocast('cuda'): if diffusion_objective == "v": fakes = sample( diffusion.model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True, ) elif diffusion_objective == "rectified_flow": t0 = time.time() fakes = sample_discrete_euler( diffusion.model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True, ) log.info(f"Sampling time: {time.time() - t0:.2f}s") if diffusion.pretransform is not None: fakes = diffusion.pretransform.decode(fakes) diffusion.to('cpu') return ( fakes.to(torch.float32) .div(torch.max(torch.abs(fakes))) .clamp(-1, 1) .mul(32767) .to(torch.int16) .cpu() ) # ==================== Full Inference Pipeline ==================== def generate_audio_core(video_file, caption): total_start_time = time.time() if video_file is None: return "❌ Please upload a video file first.", None if not caption or caption.strip() == "": caption="generate" caption = caption.strip() logs = [] def log_step(msg: str): log.info(msg) logs.append(msg) return "\n".join(logs) work_dir = tempfile.mkdtemp(prefix="PrismAudio_") try: # ---- Step 1: Convert / copy to mp4 ---- step_start = time.time() status = log_step("📹 Step 1: Preparing 
video...") src_ext = os.path.splitext(video_file)[1].lower() mp4_path = os.path.join(work_dir, "input.mp4") if src_ext != ".mp4": log_step(" Converting to mp4...") ok, err = convert_to_mp4(video_file, mp4_path) if not ok: return log_step(f"❌ Video conversion failed:\n{err}"), None else: shutil.copy(video_file, mp4_path) log_step(f"⏱️ Step 1 cost: {time.time() - step_start:.2f}s") # ---- Step 2: Validate duration ---- step_start = time.time() status = log_step("📹 Step 2: Checking video duration...") duration = get_video_duration(mp4_path) if duration > 15: #yield log_step(f"❌ Video duration {duration:.1f}s exceeds the 15s limit. Please upload a shorter video."), None return log_step(f"❌ Video duration {duration:.1f}s exceeds the 15s limit. Please upload a shorter video."), None log_step(f"⏱️ Step 2 cost: {time.time() - step_start:.2f}s") # ---- Step 3: Extract video frames ---- step_start = time.time() status = log_step("🎞️ Step 3: Extracting video frames...") clip_chunk, sync_chunk, duration = extract_video_frames(mp4_path) log_step(f"⏱️ Step 3 cost: {time.time() - step_start:.2f}s") # ---- Step 4: Extract model features ---- step_start = time.time() status = log_step("🧠 Step 4: Extracting text / video features...") info = extract_features(clip_chunk, sync_chunk, caption) log_step(f"⏱️ Step 4 cost: {time.time() - step_start:.2f}s") # ---- Step 5: Build inference batch ---- step_start = time.time() status = log_step("📦 Step 5: Building inference batch...") audio_latent, meta = build_meta(info, duration, caption) log_step(f"⏱️ Step 5 cost: {time.time() - step_start:.2f}s") # ---- Step 6: Diffusion sampling ---- step_start = time.time() status = log_step("🎵 Step 6: Running diffusion sampling...") generated_audio = run_diffusion(audio_latent, meta, duration) log_step(f"⏱️ Step 6 cost: {time.time() - step_start:.2f}s") # ---- Step 7: Save generated audio (temp) ---- step_start = time.time() status = log_step("💾 Step 7: Saving generated audio...") audio_path = 
os.path.join(work_dir, "generated_audio.wav") torchaudio.save( audio_path, generated_audio[0], # (1, T) SAMPLE_RATE, ) log_step(f"⏱️ Step 7 cost: {time.time() - step_start:.2f}s") # ---- Step 8: Mux audio into original video ---- step_start = time.time() status = log_step("🎬 Step 8: Merging audio into video...") combined_path = os.path.join(work_dir, "output_with_audio.mp4") ok, err = combine_audio_video(mp4_path, audio_path, combined_path) if not ok: return log_step(f"❌ Failed to combine audio and video:\n{err}"), None log_step(f"⏱️ Step 8 cost: {time.time() - step_start:.2f}s") total_cost = time.time() - total_start_time log_step(f"✅ Done! Audio and video merged successfully. ⏱️ Total cost: {total_cost:.2f}s") return "\n".join(logs), combined_path except Exception as e: log_step(f"❌ Unexpected error: {str(e)}") log.exception(e) return "\n".join(logs), None def generate_audio(video_file, caption): yield "⏳ Waiting for GPU...", None result_logs, result_video = generate_audio_core(video_file, caption) yield result_logs, result_video # ==================== Gradio UI ==================== def build_ui() -> gr.Blocks: with gr.Blocks( title="PrismAudio - Video to Audio Generation", theme=gr.themes.Soft(), css=""" .title { text-align:center; font-size:2em; font-weight:bold; margin-bottom:.2em; } .sub { text-align:center; color:#666; margin-bottom:1.5em; } .mono { font-family:monospace; font-size:.85em; } """, ) as demo: gr.HTML('