|
|
import spaces |
|
|
import os |
|
|
import torchaudio |
|
|
from audiocraft.models import AudioGen |
|
|
from audiocraft.data.audio import audio_write |
|
|
import gradio as gr |
|
|
|
|
|
# Default clip length, in seconds, used by the UI slider and by generate_audio.
DEFAULT_DURATION = 5

# Directory that receives the generated .wav files; created at import time
# so later writes never hit a missing folder.
OUTPUT_DIR = "audio_files"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Module-level cache for the AudioGen model. It stays None until the first
# generate_audio call loads the model and stores it here for reuse.
model = None
|
|
|
|
|
@spaces.GPU
def generate_audio(descriptions: str, duration: int = DEFAULT_DURATION) -> str:
    """Generate an audio clip from a text prompt and save it as a WAV file.

    Args:
        descriptions: Text prompt describing the sound to generate.
        duration: Requested clip length in seconds; forwarded to the model's
            generation parameters.

    Returns:
        Path of the written ``.wav`` file inside ``OUTPUT_DIR``.
    """
    global model

    # Log before the (slow) generation so the message reflects when the
    # request actually arrived.
    print(f"Received call: '{descriptions}' duration={duration}")

    # The model is loaded on the first call and cached in the module global
    # so subsequent calls reuse it.
    if model is None:
        model = AudioGen.get_pretrained('facebook/audiogen-medium')

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # The prompt is untrusted user input: never use it verbatim as a file
    # name. Replace anything that is not alphanumeric, '-' or '_' (path
    # separators, dots, quotes, ...) so the file cannot escape OUTPUT_DIR,
    # and cap the length to stay below common file-name limits.
    safe_stem = "".join(
        ch if ch.isalnum() or ch in "-_" else "_"
        for ch in descriptions.strip()
    )[:60] or "audio"
    file_path = os.path.join(OUTPUT_DIR, f"{safe_stem}.wav")

    model.set_generation_params(duration=duration)
    # A single prompt is submitted, so the first item of the batch is the clip.
    wav = model.generate([descriptions])

    # add_suffix=False makes audio_write use file_path exactly as given.
    audio_write(
        file_path,
        wav[0].cpu(),
        model.sample_rate,
        strategy="loudness",
        loudness_compressor=True,
        add_suffix=False,
    )
    print(f"Generated audio for '{descriptions}'")

    return file_path
|
|
|
|
|
|
|
|
|
|
|
# Gradio front end: a prompt box and a duration slider feed generate_audio,
# and the resulting file is shown in an audio player.
with gr.Blocks() as demo:
    gr.Markdown("# AudioGen Demo")

    # Input row: text prompt plus clip length.
    with gr.Row():
        descriptions = gr.Textbox(
            label="Enter a description of the audio",
            lines=1,
        )
        duration_slider = gr.Slider(
            label="Duration (seconds)",
            minimum=1,
            maximum=30,
            step=1,
            value=DEFAULT_DURATION,
        )

    # Action row.
    with gr.Row():
        generate_button = gr.Button("Generate Audio")

    # Output row: player for the generated clip.
    with gr.Row():
        output = gr.Audio(label="Generated Audio")

    # Wire the button to the generator; input order matches the function's
    # (descriptions, duration) parameters.
    generate_button.click(
        fn=generate_audio,
        inputs=[descriptions, duration_slider],
        outputs=output,
    )

# Enable request queuing (at most 10 pending requests), then start the
# server with a public share link requested.
demo = demo.queue(
    max_size=10,
    status_update_rate="auto",
)
demo.launch(share=True)