# Resonate / app.py
# (HF Spaces page residue preserved: AndreasXi — "add paper link" — commit 42e6049)
import warnings
import spaces
warnings.filterwarnings("ignore")
import logging
from argparse import ArgumentParser
from pathlib import Path
import torch
import torchaudio
import gradio as gr
from hydra import compose, initialize
from huggingface_hub import snapshot_download
import numpy as np
import json
from datetime import datetime
import gc
import soundfile as sf
from resonate.eval_utils import generate_fm, setup_eval_logging
from resonate.model.flow_matching import FlowMatching
from resonate.model.networks import FluxAudio, get_model
from resonate.model.utils.features_utils import FeaturesUtils
from resonate.model.sequence_config import CONFIG_16K, CONFIG_44K
# Allow TF32 matmuls/convolutions on Ampere+ GPUs for faster inference.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
log = logging.getLogger()
device = "cuda" if torch.cuda.is_available() else "cpu"
setup_eval_logging()
# Directory where generated audio clips are written.
OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Number of clips generated per request.
NUM_SAMPLE = 1
# Directory for RLHF (user preference) feedback data.
FEEDBACK_DIR = Path("./rlhf")
FEEDBACK_DIR.mkdir(exist_ok=True)
FEEDBACK_FILE = FEEDBACK_DIR / "user_preferences.jsonl"
# Global cache to avoid reloading
MODEL_CACHE = {}
FEATURE_UTILS_CACHE = {}
GLOBAL_CFG = {}
def fade_out(x: torch.Tensor, sr: int, fade_ms: int = 50) -> torch.Tensor:
    """Apply a linear fade-out to the last ``fade_ms`` milliseconds of audio.

    Args:
        x: Audio tensor, 1-D ``(samples,)`` or 2-D ``(channels, samples)``.
           A 1-D input is promoted to shape ``(1, samples)``.
        sr: Sampling rate in Hz, used to convert ``fade_ms`` to samples.
        fade_ms: Fade-out length in milliseconds.

    Returns:
        The audio with the fade applied. Unlike the previous version, the
        caller's tensor is no longer mutated in place.
    """
    if x.dim() == 1:
        x = x.unsqueeze(0)
    n = x.shape[-1]
    k = int(sr * fade_ms / 1000)
    # Nothing to fade when the window is empty or covers the whole clip.
    if k <= 0 or k >= n:
        return x
    # Clone so the caller's tensor is not silently modified, and build the
    # ramp in x's dtype/device to avoid an implicit upcast on assignment.
    x = x.clone()
    w = torch.linspace(1.0, 0.0, k, device=x.device, dtype=x.dtype)
    x[..., -k:] = x[..., -k:] * w
    return x
def ensure_models_downloaded():
    """Download the Resonate checkpoint into ./weights/ if it is missing.

    Raises:
        FileNotFoundError: if the checkpoint is absent and the download fails;
            the underlying download error is chained as the cause.
    """
    model_path = Path('./weights/Resonate_GRPO.pth')
    if not model_path.exists():
        log.info(f'Model not found at {model_path}')
        log.info('Downloading models to "./weights/"...')
        try:
            weights_dir = Path('./weights')
            weights_dir.mkdir(exist_ok=True)
            snapshot_download(repo_id="AndreasXi/resonate", local_dir="./weights")
        except Exception as e:
            log.error(f"Failed to download model: {e}")
            # Chain the original error so the real download failure stays
            # visible in the traceback (original raised a bare, f-string-less
            # FileNotFoundError with no cause).
            raise FileNotFoundError("Model file not found and download failed.") from e
def load_model_cache():
    """Load the FluxAudio network and FeaturesUtils once and cache them in
    the module-level MODEL_CACHE / FEATURE_UTILS_CACHE / GLOBAL_CFG dicts.

    Subsequent calls are no-ops.
    """
    if 'default' in MODEL_CACHE:
        return  # Already initialized.
    log.info("Loading Hydra config and initializing models...")
    # Compose the Hydra config; only compose() needs the initialize() context.
    with initialize(version_base="1.3.2", config_path="config"):
        cfg = compose(config_name='GRPO_flant5_44kMMVAE_fluxaudio_audiocaps_qwen25omni_semantic')
    GLOBAL_CFG['cfg'] = cfg
    use_rope = cfg.get('use_rope', True)
    text_dim = cfg.get('text_dim', None)
    text_c_dim = cfg.get('text_c_dim', None)
    dtype = torch.bfloat16  # Use bfloat16 as default for inference
    net: FluxAudio = get_model(cfg.model,
                               use_rope=use_rope,
                               text_dim=text_dim,
                               text_c_dim=text_c_dim).to(device, dtype).eval()
    model_path = Path('./weights/Resonate_GRPO.pth')
    # weights_only=True restricts torch.load to tensors (safer deserialization).
    net.load_weights(torch.load(model_path, map_location=device, weights_only=True))
    MODEL_CACHE['default'] = net
    log.info(f'Loaded weights from {model_path}')
    encoder_name = cfg.get('text_encoder_name', 'flan-t5')
    # Pick the VAE stack matching the configured sample rate.
    if cfg.audio_sample_rate == 16000:
        feature_utils = FeaturesUtils(tod_vae_ckpt=cfg.get('vae_16k_ckpt'),
                                      enable_conditions=True,
                                      encoder_name=encoder_name,
                                      mode='16k',
                                      bigvgan_vocoder_ckpt=cfg.get('bigvgan_vocoder_ckpt'),
                                      need_vae_encoder=True)
    elif cfg.audio_sample_rate == 44100:
        # NOTE(review): the 44k path passes no vocoder checkpoint — presumably
        # the 44k VAE decodes to waveform directly; confirm in FeaturesUtils.
        feature_utils = FeaturesUtils(tod_vae_ckpt=cfg.get('vae_44k_ckpt'),
                                      enable_conditions=True,
                                      encoder_name=encoder_name,
                                      mode='44k',
                                      need_vae_encoder=True)
    else:
        raise ValueError(f'Invalid audio sample rate: {cfg.audio_sample_rate}')
    feature_utils = feature_utils.to(device, dtype).eval()
    FEATURE_UTILS_CACHE['default'] = feature_utils
    log.info("Model and FeatureUtils loaded successfully.")
def save_preference_feedback(prompt, audio1_path, audio2_path, preference, additional_comment=""):
    """Append one user-preference record to the RLHF feedback JSONL file.

    Each record captures the prompt, both candidate audio paths, the chosen
    preference, and an optional free-text comment, stamped with the current
    time. Returns a short confirmation string for display in the UI.
    """
    record = {
        "timestamp": datetime.now().isoformat(),
        "prompt": prompt,
        "audio1_path": audio1_path,
        "audio2_path": audio2_path,
        "preference": preference,
        "additional_comment": additional_comment,
    }
    serialized = json.dumps(record, ensure_ascii=False)
    # One JSON object per line (JSONL), appended so history accumulates.
    with open(FEEDBACK_FILE, "a", encoding="utf-8") as fh:
        fh.write(serialized + "\n")
    log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
    return f"✅ Thanks for your feedback, preference recorded: {preference}"
@spaces.GPU(duration=60)
@torch.inference_mode()
def generate_audio_gradio(
    prompt,
    negative_prompt,
    duration,
    cfg_strength,
    num_steps,
    seed
):
    """Generate one audio clip from a text prompt and save it as FLAC.

    Args:
        prompt: Text description of the desired audio; a default prompt is
            substituted when blank.
        negative_prompt: Text describing elements to avoid.
        duration: Clip length in seconds (> 0).
        cfg_strength: Classifier-free-guidance scale.
        num_steps: Number of Euler sampling steps (> 0).
        seed: RNG seed for reproducible sampling.

    Returns:
        Tuple of (path to the saved FLAC file, the prompt actually used).

    Raises:
        ValueError: if duration or num_steps is not positive.
    """
    if duration <= 0 or num_steps <= 0:
        raise ValueError("Duration and number of steps must be positive.")
    if not prompt.strip():
        prompt = "A dog is barking"  # Default fallback
    # Gradio may deliver slider values as floats (e.g. via the HTTP API);
    # manual_seed and the step count require ints, so normalize defensively.
    num_steps = int(num_steps)
    seed = int(seed)
    net = MODEL_CACHE['default']
    feature_utils = FEATURE_UTILS_CACHE['default']
    cfg = GLOBAL_CFG['cfg']
    # Pick the sequence config matching the model's sample rate, then size
    # it to the requested duration.
    if cfg.audio_sample_rate == 16000:
        seq_cfg = CONFIG_16K
    else:
        seq_cfg = CONFIG_44K
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len)
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    log.info(f'Generating with Prompt: "{prompt}", Negative Prompt: "{negative_prompt}"')
    audios = generate_fm(
        [prompt] * NUM_SAMPLE,
        negative_text=[negative_prompt] * NUM_SAMPLE,
        feature_utils=feature_utils,
        net=net,
        fm=fm,
        rng=rng,
        cfg_strength=cfg_strength,
    )
    save_paths = []
    # Build a filesystem-safe slug from the prompt for the output filename.
    safe_prompt = (
        "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
        .rstrip()
        .replace(" ", "_")[:50]
    )
    for i, audio in enumerate(audios):
        audio = audio.float().cpu()
        audio = fade_out(audio, seq_cfg.sampling_rate, fade_ms=100)
        current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
        save_path = OUTPUT_DIR / filename
        sf.write(str(save_path), audio.squeeze().cpu().numpy(), seq_cfg.sampling_rate)
        log.info(f"Audio saved to {save_path}")
        save_paths.append(str(save_path))
    if device == "cuda":
        torch.cuda.empty_cache()  # Free GPU memory between requests.
    return save_paths[0], prompt
# Gradio input and output components
input_text = gr.Textbox(lines=2, label="Prompt", placeholder="Describe the audio you want to generate...")
negative_prompt = gr.Textbox(lines=1, label="Negative Prompt", placeholder="Elements you want to avoid...")
# NOTE(review): output_audio is defined but not wired into the Interface
# below (which constructs its own gr.Audio) — confirm whether it is needed.
output_audio = gr.Audio(label="Generated Audio", type="filepath")
denoising_steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Sampling Steps", interactive=True)
# NOTE(review): step=1 with default 4.5 yields 4.5, 5.5, ... — confirm the
# intended CFG granularity.
cfg_strength = gr.Slider(minimum=1, maximum=10, value=4.5, step=1, label="Guidance Scale (CFG)", interactive=True)
duration = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Duration (seconds)", interactive=True)
seed = gr.Slider(minimum=1, maximum=10000, value=123, step=1, label="Seed", interactive=True)
# Markdown rendered under the app title.
description_text = """
### **Resonate** is a novel text-to-audio generation model.
### [📖 **Arxiv**](https://arxiv.org/abs/2603.11661) | [💻 **GitHub**](https://github.com/xiquan-li/Resonate) | [🤗 **Model**](https://huggingface.co/AndreasXi/resonate) | [🌐 **Project Page**](https://resonatedemo.github.io/)
"""
# Wire the components: inputs -> generate_audio_gradio -> (audio file, prompt echo).
gr_interface = gr.Interface(
    fn=generate_audio_gradio,
    inputs=[input_text, negative_prompt, duration, cfg_strength, denoising_steps, seed],
    outputs=[
        gr.Audio(label="🎵 Audio Sample", type="filepath"),
        gr.Textbox(label="Prompt Used", interactive=False)
    ],
    title="Resonate: Reinforcing Text-to-Audio Generation with Online Feedbacks from Large Audio Language Models",
    description=description_text,
    flagging_mode="never",
    # Each example row: [prompt, negative_prompt, duration, cfg, steps, seed].
    examples=[
        # ["A cat is meowing, followed by guitar sound", "", 10, 4.5, 50, 123],
        ["Rain falls onto a hard surface", "", 10, 4.5, 50, 123],
        ["The microwave beeps three times, followed by the creak of its door opening and the hiss of steam escaping.", "", 10, 4.5, 50, 123],
        ["The audience clapping rhythmically followed by a guitarist beginning an acoustic set.", "", 10, 4.5, 50, 123],
        ["A security tag alarm blares, stops abruptly, then nervous laughter follows.", "", 10, 4.5, 50, 123],
        ["Horse clip-clopping and heavy wind", "", 10, 4.5, 50, 123],
        ["Camera muffling followed by a person whistling then plastic clacking as birds chirp in the background", "", 10, 4.5, 50, 123],
        ["A girl speaks, rustling of a camera followed by computer typing", "", 10, 4.5, 50, 123],
        ["First, the mellow strum of an acoustic guitar fills the air, followed by the sharp hiss of a spray can releasing vibrant colors.", "", 10, 4.5, 50, 123],
        ["The rhythmic ticking of an analog wall clock marks time above the stove.", "", 10, 4.5, 50, 123],
        ["Simultaneous chatter, laughter, and clinking utensils create a lively dinner ambiance.", "", 10, 4.5, 50, 123],
        ["A car door slams, engine revs, and tires screech away.", "", 10, 4.5, 50, 123],
        ["Simultaneous sounds of typing, paper rustling, clock ticking, and distant traffic create a work ambiance.", "", 10, 4.5, 50, 123],
        ["Melodic human whistling harmonizing with natural birdsong", "", 10, 4.5, 50, 123],
        ["A parade marches through a town square, with drumbeats pounding, children clapping, and a horse neighing amidst the commotion", "", 10, 4.5, 50, 123],
        ["Quiet speech and then and airplane flying away", "", 10, 4.5, 50, 123],
        ["Battlefield scene, continuous roar of artillery and gunfire, high fidelity, the sharp crack of bullets, the thundering explosions of bombs", "", 10, 4.5, 50, 123],
    ],
    # Example outputs are only generated on first click, not at startup.
    cache_examples="lazy",
)
if __name__ == "__main__":
    # Fetch weights if missing, warm the model caches, then serve the UI
    # with a bounded request queue.
    ensure_models_downloaded()
    load_model_cache()
    gr_interface.queue(max_size=15).launch()