| import warnings |
| import spaces |
| warnings.filterwarnings("ignore") |
| import logging |
| from argparse import ArgumentParser |
| from pathlib import Path |
| import torch |
| import torchaudio |
| import gradio as gr |
| from hydra import compose, initialize |
| from huggingface_hub import snapshot_download |
| import numpy as np |
| import json |
| from datetime import datetime |
| import gc |
| import soundfile as sf |
|
|
| from resonate.eval_utils import generate_fm, setup_eval_logging |
| from resonate.model.flow_matching import FlowMatching |
| from resonate.model.networks import FluxAudio, get_model |
| from resonate.model.utils.features_utils import FeaturesUtils |
| from resonate.model.sequence_config import CONFIG_16K, CONFIG_44K |
|
|
# Allow TF32 matmul/conv kernels on Ampere+ GPUs for faster inference at
# slightly reduced float precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# Root logger; configured by setup_eval_logging() below.
log = logging.getLogger()
device = "cuda" if torch.cuda.is_available() else "cpu"
setup_eval_logging()


# Where generated audio files are written.
OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Number of audio samples generated per request.
NUM_SAMPLE = 1



# User preference feedback (RLHF) is appended to a JSONL file.
FEEDBACK_DIR = Path("./rlhf")
FEEDBACK_DIR.mkdir(exist_ok=True)
FEEDBACK_FILE = FEEDBACK_DIR / "user_preferences.jsonl"



# Process-wide singletons populated once by load_model_cache() and read by
# generate_audio_gradio().
MODEL_CACHE = {}
FEATURE_UTILS_CACHE = {}
GLOBAL_CFG = {}
|
|
def fade_out(x: torch.Tensor, sr: int, fade_ms: int = 50) -> torch.Tensor:
    """Apply a linear fade-out to the end of an audio tensor.

    Args:
        x: Audio samples, shape (num_samples,) or (channels, num_samples).
           A 1-D input is promoted to (1, num_samples).
        sr: Sample rate in Hz, used to convert ``fade_ms`` into samples.
        fade_ms: Fade duration in milliseconds.

    Returns:
        A 2-D tensor with the last ``fade_ms`` milliseconds scaled linearly
        from 1.0 down to 0.0. The input tensor is left unmodified (the
        original implementation wrote into the caller's tensor in place).
    """
    if x.dim() == 1:
        x = x.unsqueeze(0)
    n = x.shape[-1]
    k = int(sr * fade_ms / 1000)
    # Nothing to fade if the window is empty or would cover the whole clip.
    if k <= 0 or k >= n:
        return x
    # Clone so the fade is not a side effect on the caller's tensor.
    x = x.clone()
    w = torch.linspace(1.0, 0.0, k, device=x.device)
    x[..., -k:] = x[..., -k:] * w
    return x
|
|
def ensure_models_downloaded():
    """Fetch the Resonate weights from the Hugging Face Hub if absent.

    Looks for ``./weights/Resonate_GRPO.pth`` and, when missing, snapshots
    the ``AndreasXi/resonate`` repo into ``./weights/``.

    Raises:
        FileNotFoundError: if the weights are missing and the download
            fails; the original download exception is chained as the cause.
    """
    model_path = Path('./weights/Resonate_GRPO.pth')
    if not model_path.exists():
        log.info(f'Model not found at {model_path}')
        log.info('Downloading models to "./weights/"...')
        try:
            weights_dir = Path('./weights')
            weights_dir.mkdir(exist_ok=True)
            snapshot_download(repo_id="AndreasXi/resonate", local_dir="./weights")
        except Exception as e:
            log.error(f"Failed to download model: {e}")
            # Chain the original error so the real download failure is not
            # lost from the traceback (the old raise dropped it).
            raise FileNotFoundError(
                f"Model file not found at {model_path} and download failed."
            ) from e
|
|
def load_model_cache():
    """Load the Resonate network and feature utilities into module caches.

    Idempotent: returns immediately if MODEL_CACHE is already populated.
    Side effects: fills MODEL_CACHE['default'], FEATURE_UTILS_CACHE['default']
    and GLOBAL_CFG['cfg'] for use by generate_audio_gradio().
    """
    if 'default' in MODEL_CACHE:
        return

    log.info("Loading Hydra config and initializing models...")
    # Compose the Hydra config from ./config relative to this file.
    with initialize(version_base="1.3.2", config_path="config"):
        cfg = compose(config_name='GRPO_flant5_44kMMVAE_fluxaudio_audiocaps_qwen25omni_semantic')
    GLOBAL_CFG['cfg'] = cfg


    use_rope = cfg.get('use_rope', True)
    text_dim = cfg.get('text_dim', None)
    text_c_dim = cfg.get('text_c_dim', None)
    # Whole pipeline runs in bfloat16.
    dtype = torch.bfloat16


    net: FluxAudio = get_model(cfg.model,
                               use_rope=use_rope,
                               text_dim=text_dim,
                               text_c_dim=text_c_dim).to(device, dtype).eval()

    model_path = Path('./weights/Resonate_GRPO.pth')
    # weights_only=True restricts torch.load to tensors (no pickled code).
    net.load_weights(torch.load(model_path, map_location=device, weights_only=True))
    MODEL_CACHE['default'] = net
    log.info(f'Loaded weights from {model_path}')


    # Pick the VAE/vocoder stack matching the configured sample rate;
    # 16 kHz additionally needs a BigVGAN vocoder checkpoint.
    encoder_name = cfg.get('text_encoder_name', 'flan-t5')
    if cfg.audio_sample_rate == 16000:
        feature_utils = FeaturesUtils(tod_vae_ckpt=cfg.get('vae_16k_ckpt'),
                                      enable_conditions=True,
                                      encoder_name=encoder_name,
                                      mode='16k',
                                      bigvgan_vocoder_ckpt=cfg.get('bigvgan_vocoder_ckpt'),
                                      need_vae_encoder=True)
    elif cfg.audio_sample_rate == 44100:
        feature_utils = FeaturesUtils(tod_vae_ckpt=cfg.get('vae_44k_ckpt'),
                                      enable_conditions=True,
                                      encoder_name=encoder_name,
                                      mode='44k',
                                      need_vae_encoder=True)
    else:
        raise ValueError(f'Invalid audio sample rate: {cfg.audio_sample_rate}')

    feature_utils = feature_utils.to(device, dtype).eval()
    FEATURE_UTILS_CACHE['default'] = feature_utils
    log.info("Model and FeatureUtils loaded successfully.")
|
|
def save_preference_feedback(prompt, audio1_path, audio2_path, preference, additional_comment=""):
    """Append one user preference record to the RLHF feedback JSONL file.

    Returns a short confirmation string shown back to the user in the UI.
    """
    record = dict(
        timestamp=datetime.now().isoformat(),
        prompt=prompt,
        audio1_path=audio1_path,
        audio2_path=audio2_path,
        preference=preference,
        additional_comment=additional_comment,
    )

    # One JSON object per line (JSONL), appended atomically enough for a demo.
    with FEEDBACK_FILE.open("a", encoding="utf-8") as fh:
        fh.write(json.dumps(record, ensure_ascii=False) + "\n")

    log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
    return f"✅ Thanks for your feedback, preference recorded: {preference}"
|
|
|
|
@spaces.GPU(duration=60)
@torch.inference_mode()
def generate_audio_gradio(
    prompt,
    negative_prompt,
    duration,
    cfg_strength,
    num_steps,
    seed
):
    """Generate one audio clip from a text prompt and save it as FLAC.

    Reads the model, feature utils and config from the module caches
    populated by load_model_cache(). Returns (path_to_flac, prompt_used)
    for the two Gradio outputs.

    Raises:
        ValueError: if duration or num_steps is not positive.
    """
    if duration <= 0 or num_steps <= 0:
        raise ValueError("Duration and number of steps must be positive.")
    # Empty prompts fall back to a fixed default.
    if not prompt.strip():
        prompt = "A dog is barking"

    net = MODEL_CACHE['default']
    feature_utils = FEATURE_UTILS_CACHE['default']
    cfg = GLOBAL_CFG['cfg']


    # Sequence config depends on the model's sample rate (16 kHz vs 44.1 kHz).
    if cfg.audio_sample_rate == 16000:
        seq_cfg = CONFIG_16K
    else:
        seq_cfg = CONFIG_44K


    # NOTE(review): CONFIG_16K/CONFIG_44K are module-level objects, so this
    # mutates shared state; concurrent requests with different durations
    # could race — confirm Gradio queue serializes GPU calls.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len)


    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)


    # Dedicated generator so results are reproducible for a given seed.
    # NOTE(review): Gradio sliders may deliver floats; manual_seed needs an
    # int — verify the slider always yields ints with step=1.
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)


    log.info(f'Generating with Prompt: "{prompt}", Negative Prompt: "{negative_prompt}"')


    audios = generate_fm(
        [prompt] * NUM_SAMPLE,
        negative_text=[negative_prompt] * NUM_SAMPLE,
        feature_utils=feature_utils,
        net=net,
        fm=fm,
        rng=rng,
        cfg_strength=cfg_strength,
    )

    save_paths = []
    # Sanitize the prompt into a filesystem-safe filename prefix.
    safe_prompt = (
        "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
        .rstrip()
        .replace(" ", "_")[:50]
    )

    for i, audio in enumerate(audios):
        audio = audio.float().cpu()
        # 100 ms linear fade to avoid an audible click at the clip end.
        audio = fade_out(audio, seq_cfg.sampling_rate, fade_ms=100)


        # Microsecond timestamp keeps filenames unique across rapid requests.
        current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
        save_path = OUTPUT_DIR / filename

        sf.write(str(save_path), audio.squeeze().cpu().numpy(), seq_cfg.sampling_rate)
        log.info(f"Audio saved to {save_path}")
        save_paths.append(str(save_path))


    # Release cached GPU memory between requests.
    if device == "cuda":
        torch.cuda.empty_cache()


    return save_paths[0], prompt
|
|
|
|
| |
# Gradio input/output components wired into gr.Interface below.
input_text = gr.Textbox(lines=2, label="Prompt", placeholder="Describe the audio you want to generate...")
negative_prompt = gr.Textbox(lines=1, label="Negative Prompt", placeholder="Elements you want to avoid...")
output_audio = gr.Audio(label="Generated Audio", type="filepath")
denoising_steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Sampling Steps", interactive=True)
# NOTE(review): default 4.5 with step=1 — the slider will snap to integers,
# so the 4.5 default may not be selectable again once moved; confirm intended.
cfg_strength = gr.Slider(minimum=1, maximum=10, value=4.5, step=1, label="Guidance Scale (CFG)", interactive=True)
duration = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Duration (seconds)", interactive=True)
seed = gr.Slider(minimum=1, maximum=10000, value=123, step=1, label="Seed", interactive=True)




# Markdown shown under the demo title.
description_text = """
### **Resonate** is a novel text-to-audio generation model.
### [📖 **Arxiv**](https://arxiv.org/abs/2603.11661) | [💻 **GitHub**](https://github.com/xiquan-li/Resonate) | [🤗 **Model**](https://huggingface.co/AndreasXi/resonate) | [🌐 **Project Page**](https://resonatedemo.github.io/)
"""
|
|
# Top-level Gradio app: maps the six inputs onto generate_audio_gradio and
# shows the generated clip plus the prompt actually used.
gr_interface = gr.Interface(
    fn=generate_audio_gradio,
    # Order must match generate_audio_gradio's positional parameters.
    inputs=[input_text, negative_prompt, duration, cfg_strength, denoising_steps, seed],
    outputs=[
        gr.Audio(label="🎵 Audio Sample", type="filepath"),
        gr.Textbox(label="Prompt Used", interactive=False)
    ],
    title="Resonate: Reinforcing Text-to-Audio Generation with Online Feedbacks from Large Audio Language Models",
    description=description_text,
    flagging_mode="never",
    # Each example row: [prompt, negative_prompt, duration, cfg, steps, seed].
    examples=[

        ["Rain falls onto a hard surface", "", 10, 4.5, 50, 123],
        ["The microwave beeps three times, followed by the creak of its door opening and the hiss of steam escaping.", "", 10, 4.5, 50, 123],
        ["The audience clapping rhythmically followed by a guitarist beginning an acoustic set.", "", 10, 4.5, 50, 123],
        ["A security tag alarm blares, stops abruptly, then nervous laughter follows.", "", 10, 4.5, 50, 123],
        ["Horse clip-clopping and heavy wind", "", 10, 4.5, 50, 123],
        ["Camera muffling followed by a person whistling then plastic clacking as birds chirp in the background", "", 10, 4.5, 50, 123],
        ["A girl speaks, rustling of a camera followed by computer typing", "", 10, 4.5, 50, 123],
        ["First, the mellow strum of an acoustic guitar fills the air, followed by the sharp hiss of a spray can releasing vibrant colors.", "", 10, 4.5, 50, 123],
        ["The rhythmic ticking of an analog wall clock marks time above the stove.", "", 10, 4.5, 50, 123],
        ["Simultaneous chatter, laughter, and clinking utensils create a lively dinner ambiance.", "", 10, 4.5, 50, 123],
        ["A car door slams, engine revs, and tires screech away.", "", 10, 4.5, 50, 123],
        ["Simultaneous sounds of typing, paper rustling, clock ticking, and distant traffic create a work ambiance.", "", 10, 4.5, 50, 123],
        ["Melodic human whistling harmonizing with natural birdsong", "", 10, 4.5, 50, 123],
        ["A parade marches through a town square, with drumbeats pounding, children clapping, and a horse neighing amidst the commotion", "", 10, 4.5, 50, 123],
        ["Quiet speech and then and airplane flying away", "", 10, 4.5, 50, 123],
        ["Battlefield scene, continuous roar of artillery and gunfire, high fidelity, the sharp crack of bullets, the thundering explosions of bombs", "", 10, 4.5, 50, 123],
    ],
    # Examples are rendered immediately but only generated on first click.
    cache_examples="lazy",
)
|
|
|
|
if __name__ == "__main__":
    # Download weights if needed, load them once, then serve the UI with a
    # bounded request queue (max 15 pending jobs).
    ensure_models_downloaded()
    load_model_cache()
    gr_interface.queue(max_size=15).launch()