| import spaces |
| import gradio as gr |
| import torch |
| import torchaudio |
| import os |
| from einops import rearrange |
| import gc |
| from stable_audio_tools import get_pretrained_model |
| from stable_audio_tools.inference.generation import generate_diffusion_cond |
from stable_audio_tools.data.utils import read_video, merge_video_audio, load_and_process_audio
from aeiou.viz import audio_spectrogram_image  # spectrogram helper used by the preview callback below
| import stat |
| import platform |
| import logging |
| from transformers import logging as transformers_logging |
|
|
| transformers_logging.set_verbosity_error() |
| logging.getLogger("transformers").setLevel(logging.ERROR) |
|
|
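# Load the pretrained AudioX weights and config once at startup; the model's
# native sample rate and window size (sample_size) drive everything downstream.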
| model, model_config = get_pretrained_model('HKUSTAudio/AudioX') |
| sample_rate = model_config["sample_rate"] |
| sample_size = model_config["sample_size"] |
|
|
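# World-writable temp directories so Gradio can stage uploaded files regardless
# of which user the server process runs as.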
| TEMP_DIR = "tmp/gradio" |
| os.makedirs(TEMP_DIR, exist_ok=True) |
| os.chmod(TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) |
|
|
| VIDEO_TEMP_DIR = os.path.join(TEMP_DIR, "videos") |
| os.makedirs(VIDEO_TEMP_DIR, exist_ok=True) |
| os.chmod(VIDEO_TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) |
|
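# On Hugging Face ZeroGPU Spaces, @spaces.GPU borrows a GPU for roughly the
# given number of seconds per call; 10 s may be tight for 100 sampler steps.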
| @spaces.GPU(duration=10) |
| def generate_cond( |
| prompt, |
| negative_prompt=None, |
| video_file=None, |
| audio_prompt_file=None, |
| audio_prompt_path=None, |
| seconds_start=0, |
| seconds_total=10, |
| cfg_scale=7.0, |
| steps=100, |
| preview_every=0, |
| seed=-1, |
| sampler_type="dpmpp-3m-sde", |
| sigma_min=0.03, |
| sigma_max=500, |
| cfg_rescale=0.0, |
| use_init=False, |
| init_audio=None, |
| init_noise_level=0.1, |
| batch_size=1 |
| ): |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| gc.collect() |
| print(f"Prompt: {prompt}") |
| preview_images = [] |
| if preview_every == 0: |
| preview_every = None |
|
|
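    # Pick the best available backend: Apple MPS, then CUDA, then CPU.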
| try: |
| has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available() |
| except Exception: |
| has_mps = False |
| if has_mps: |
| device = torch.device("mps") |
| elif torch.cuda.is_available(): |
| device = torch.device("cuda") |
| else: |
| device = torch.device("cpu") |
| |
| global model |
| model = model.to(device) |
|
|
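    # Conditioning settings from the model config: the fps at which video
    # frames are sampled, and the model type expected below.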
| target_fps = model_config.get("video_fps", 5) |
| model_type = model_config.get("model_type", "diffusion_cond") |
|
|
    # Gradio's File component may hand back a dict, a tempfile-like object with
    # a .name attribute, or a plain filepath string depending on the version.
    if isinstance(video_file, dict):
        actual_video_path = video_file['name']
    elif isinstance(video_file, str):
        actual_video_path = video_file
    elif video_file is not None:
        actual_video_path = video_file.name
    else:
        actual_video_path = None
|
|
    if isinstance(audio_prompt_file, str):
        audio_path = audio_prompt_file
    elif audio_prompt_file is not None:
        audio_path = audio_prompt_file.name
    elif audio_prompt_path:
        audio_path = audio_prompt_path.strip()
    else:
        audio_path = None
|
|
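    # Load the conditioning media: video frames from the requested window and
    # the audio prompt processed to the model's sample rate.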
    video_tensors = read_video(actual_video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
    audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
|
|
| audio_tensor = audio_tensor.to(device) |
| seconds_input = sample_size / sample_rate |
| |
| if not prompt: |
| prompt = "" |
|
|
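    # stable-audio-tools expects a list of per-sample conditioning dicts whose
    # keys match the conditioner names in the model config.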
    conditioning = [{
        "video_prompt": [video_tensors.unsqueeze(0)],
        "text_prompt": prompt,
        "audio_prompt": audio_tensor.unsqueeze(0),
        "seconds_start": seconds_start,
        "seconds_total": seconds_input
    }]
    if negative_prompt:
        negative_conditioning = [{
            "video_prompt": [video_tensors.unsqueeze(0)],
            "text_prompt": negative_prompt,
            "audio_prompt": audio_tensor.unsqueeze(0),
            "seconds_start": seconds_start,
            # Use the same window length as the positive conditioning.
            "seconds_total": seconds_input
        }]
    else:
        negative_conditioning = None
|
|
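    # The seed arrives as a string from the UI textbox; -1 asks the sampler to
    # draw a random seed.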
| seed = int(seed) |
| if not use_init: |
| init_audio = None |
| input_sample_size = sample_size |
|
|
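    # Sampler callback: every `preview_every` steps, decode the current latents
    # and render a spectrogram. The images are collected in preview_images but
    # are not wired to any output component in this UI.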
| def progress_callback(callback_info): |
| nonlocal preview_images |
| denoised = callback_info["denoised"] |
| current_step = callback_info["i"] |
| sigma = callback_info["sigma"] |
| if (current_step - 1) % preview_every == 0: |
| if model.pretransform is not None: |
| denoised = model.pretransform.decode(denoised) |
| denoised = rearrange(denoised, "b d n -> d (b n)") |
| denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu() |
| audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate) |
            preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))
| |
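    # Run conditioned diffusion sampling; this is the expensive GPU-bound step.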
| if model_type == "diffusion_cond": |
| audio = generate_diffusion_cond( |
| model, |
| conditioning=conditioning, |
| negative_conditioning=negative_conditioning, |
| steps=steps, |
| cfg_scale=cfg_scale, |
| batch_size=batch_size, |
| sample_size=input_sample_size, |
| sample_rate=sample_rate, |
| seed=seed, |
| device=device, |
| sampler_type=sampler_type, |
| sigma_min=sigma_min, |
| sigma_max=sigma_max, |
| init_audio=init_audio, |
| init_noise_level=init_noise_level, |
| mask_args=None, |
| callback=progress_callback if preview_every is not None else None, |
| scale_phi=cfg_rescale |
| ) |
|
|
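    # Fold the batch into the time axis: (batch, channels, samples) -> (channels, batch * samples).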
| audio = rearrange(audio, "b d n -> d (b n)") |
|
|
    # Trim to the first 10 seconds, peak-normalize, and convert to 16-bit PCM.
    samples_10s = 10 * sample_rate
    audio = audio[:, :samples_10s]
    peak = torch.max(torch.abs(audio)).clamp(min=1e-8)  # guard against all-silent output
    audio = audio.to(torch.float32).div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
|
|
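    # Persist the result as a 16-bit WAV under demo_result/.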
| output_dir = "demo_result" |
| os.makedirs(output_dir, exist_ok=True) |
| output_audio_path = f"{output_dir}/output.wav" |
| torchaudio.save(output_audio_path, audio, sample_rate) |
|
|
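    # If a video was provided, mux the generated audio back into the clip so
    # the output video plays with the new soundtrack.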
| if actual_video_path: |
| output_video_path = f"{output_dir}/{os.path.basename(actual_video_path)}" |
| merge_video_audio( |
| actual_video_path, |
| output_audio_path, |
| output_video_path, |
| seconds_start, |
| seconds_total |
| ) |
| else: |
| output_video_path = None |
|
|
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
|
|
| return output_video_path, output_audio_path |
|
|
|
|
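# Gradio UI: a single "Generation" tab with prompt/video inputs, collapsible
# sampler controls, and paired video/audio outputs.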
| with gr.Blocks() as interface: |
| gr.Markdown( |
| """ |
| # 🎧AudioX: Diffusion Transformer for Anything-to-Audio Generation |
| **[Paper](https://arxiv.org/abs/2503.10522) · [Project Page](https://zeyuet.github.io/AudioX/) · [Huggingface](https://huggingface.co/HKUSTAudio/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)** |
| """ |
| ) |
|
|
| with gr.Tab("Generation"): |
| with gr.Row(): |
| with gr.Column(): |
| prompt = gr.Textbox( |
| show_label=False, |
| placeholder="Enter your prompt" |
| ) |
| negative_prompt = gr.Textbox( |
| show_label=False, |
| placeholder="Negative prompt", |
| visible=False |
| ) |
| video_file = gr.File(label="Upload Video File") |
| audio_prompt_file = gr.File( |
| label="Upload Audio Prompt File", |
| visible=False |
| ) |
| audio_prompt_path = gr.Textbox( |
| label="Audio Prompt Path", |
| placeholder="Enter audio file path", |
| visible=False |
| ) |
|
|
| with gr.Row(): |
| with gr.Column(scale=6): |
| with gr.Accordion("Video Params", open=False): |
| seconds_start = gr.Slider( |
| minimum=0, |
| maximum=512, |
| step=1, |
| value=0, |
| label="Video Seconds Start" |
| ) |
| seconds_total = gr.Slider( |
| minimum=0, |
| maximum=10, |
| step=1, |
| value=10, |
| label="Seconds Total", |
| interactive=False |
| ) |
|
|
| with gr.Row(): |
| with gr.Column(scale=4): |
| with gr.Accordion("Sampler Params", open=False): |
| steps = gr.Slider( |
| minimum=1, |
| maximum=500, |
| step=1, |
| value=100, |
| label="Steps" |
| ) |
| preview_every = gr.Slider( |
| minimum=0, |
| maximum=100, |
| step=1, |
| value=0, |
| label="Preview Every" |
| ) |
| cfg_scale = gr.Slider( |
| minimum=0.0, |
| maximum=25.0, |
| step=0.1, |
| value=7.0, |
| label="CFG Scale" |
| ) |
| seed = gr.Textbox( |
| label="Seed (set to -1 for random seed)", |
| value="-1" |
| ) |
| sampler_type = gr.Dropdown( |
| choices=[ |
| "dpmpp-2m-sde", |
| "dpmpp-3m-sde", |
| "k-heun", |
| "k-lms", |
| "k-dpmpp-2s-ancestral", |
| "k-dpm-2", |
| "k-dpm-fast" |
| ], |
| label="Sampler Type", |
| value="dpmpp-3m-sde" |
| ) |
| sigma_min = gr.Slider( |
| minimum=0.0, |
| maximum=2.0, |
| step=0.01, |
| value=0.03, |
| label="Sigma Min" |
| ) |
| sigma_max = gr.Slider( |
| minimum=0.0, |
| maximum=1000.0, |
| step=0.1, |
| value=500, |
| label="Sigma Max" |
| ) |
| cfg_rescale = gr.Slider( |
| minimum=0.0, |
| maximum=1, |
| step=0.01, |
| value=0.0, |
| label="CFG Rescale Amount" |
| ) |
|
|
| with gr.Row(): |
| with gr.Column(scale=4): |
| with gr.Accordion("Init Audio", open=False, visible=False): |
| init_audio_checkbox = gr.Checkbox(label="Use Init Audio") |
| init_audio_input = gr.Audio(label="Init Audio") |
| init_noise_level = gr.Slider( |
| minimum=0.1, |
| maximum=100.0, |
| step=0.01, |
| value=0.1, |
| label="Init Noise Level" |
| ) |
|
|
| with gr.Row(): |
| generate_button = gr.Button("Generate", variant="primary") |
|
|
| with gr.Row(): |
| with gr.Column(scale=6): |
| video_output = gr.Video(label="Output Video", interactive=False) |
| audio_output = gr.Audio(label="Output Audio", interactive=False) |
|
|
| inputs = [ |
| prompt, |
| negative_prompt, |
| video_file, |
| audio_prompt_file, |
| audio_prompt_path, |
| seconds_start, |
| seconds_total, |
| cfg_scale, |
| steps, |
| preview_every, |
| seed, |
| sampler_type, |
| sigma_min, |
| sigma_max, |
| cfg_rescale, |
| init_audio_checkbox, |
| init_audio_input, |
| init_noise_level |
| ] |
|
|
| generate_button.click( |
| fn=generate_cond, |
| inputs=inputs, |
| outputs=[video_output, audio_output] |
| ) |
|
|
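    # Example loaders: each button fills the full 18-element input list above,
    # including a fixed seed for reproducibility.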
| gr.Markdown("## Examples") |
| with gr.Accordion("Click to show examples", open=False): |
| with gr.Row(): |
| gr.Markdown("**📝 Task: Text-to-Audio**") |
            with gr.Column(scale=1):
| gr.Markdown("Prompt: *Typing on a keyboard*") |
| ex1 = gr.Button("Load Example") |
            with gr.Column(scale=1):
| gr.Markdown("Prompt: *Ocean waves crashing*") |
| ex2 = gr.Button("Load Example") |
            with gr.Column(scale=1):
| gr.Markdown("Prompt: *Footsteps in snow*") |
| ex3 = gr.Button("Load Example") |
| |
| with gr.Row(): |
| gr.Markdown("**🎶 Task: Text-to-Music**") |
            with gr.Column(scale=1):
| gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*") |
| ex4 = gr.Button("Load Example") |
            with gr.Column(scale=1):
| gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*") |
| ex5 = gr.Button("Load Example") |
            with gr.Column(scale=1):
| gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*") |
| ex6 = gr.Button("Load Example") |
|
|
| ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) |
| ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) |
| ex3.click(lambda: ["Footsteps in snow", None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) |
| ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) |
| ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) |
| ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) |
|
|
# A bare positional to queue() is ambiguous across Gradio versions; name the
# argument explicitly (concurrency_count in Gradio 3 became
# default_concurrency_limit in Gradio 4).
interface.queue(default_concurrency_limit=5).launch(server_name="0.0.0.0", server_port=7860, share=True)