|
|
import functools
import hashlib
import os
import re

import gradio as gr
import torchaudio
from audiocraft.data.audio import audio_write
from audiocraft.models import AudioGen
|
|
|
|
|
# Directory (relative to the working directory) where generated clips are saved.
_OUTPUT_DIR = "audio_files"
# Checkpoint name passed to AudioGen.get_pretrained.
_MODEL_NAME = "facebook/audiogen-medium"
# Anything that is not a word character or hyphen is replaced in file names,
# which rules out path separators, dots (so no ".." or hidden files), etc.
_UNSAFE_FILENAME_CHARS = re.compile(r"[^\w-]+")
# Cap on the human-readable part of a file name so long prompts stay well
# below typical filesystem name-length limits.
_MAX_STEM_LEN = 80


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the AudioGen checkpoint once and reuse it for every request."""
    return AudioGen.get_pretrained(_MODEL_NAME)


def _safe_filename(description):
    """Return a filesystem-safe ``.wav`` file name derived from a prompt.

    The readable stem keeps only word characters and hyphens, so user text
    can never escape the output directory. A short SHA-256 digest of the full
    prompt keeps different prompts from mapping to the same file after
    sanitization (and gives a name even when nothing readable survives).
    """
    stem = _UNSAFE_FILENAME_CHARS.sub("_", description).strip("_")[:_MAX_STEM_LEN]
    digest = hashlib.sha256(description.encode("utf-8")).hexdigest()[:8]
    return f"{stem or 'audio'}_{digest}.wav"


def generate_audio(descriptions, duration=5):
    """Generate a sound clip from a text prompt and save it as a WAV file.

    Args:
        descriptions: Free-text prompt for AudioGen (a single string, as
            supplied by the Gradio textbox). Surrounding whitespace is ignored.
        duration: Value passed to ``model.set_generation_params(duration=...)``.
            Defaults to 5, the value previously hard-coded.

    Returns:
        Path of the written ``.wav`` file inside ``audio_files``.

    Raises:
        gr.Error: If the prompt is empty or whitespace-only, so the UI shows a
            message instead of generating from an empty prompt.
    """
    prompt = (descriptions or "").strip()
    if not prompt:
        raise gr.Error("Please enter a description of the audio to generate.")

    os.makedirs(_OUTPUT_DIR, exist_ok=True)

    model = _load_model()
    # The model object is shared across calls, so set generation params each time.
    model.set_generation_params(duration=duration)
    # Only one prompt is submitted; the previous code also returned only the
    # first generated waveform.
    wav = model.generate([prompt])[0]

    file_path = os.path.join(_OUTPUT_DIR, _safe_filename(prompt))
    # file_path already ends in ".wav"; add_suffix=False keeps audio_write
    # from appending another suffix.
    audio_write(
        file_path,
        wav.cpu(),
        model.sample_rate,
        strategy="loudness",
        loudness_compressor=True,
        add_suffix=False,
    )
    print(f"Generated audio for '{prompt}'")
    return file_path
|
|
|
|
|
def ui_full():
    """Build the AudioGen Gradio demo page and launch it with queuing enabled.

    The page has a text box for the prompt, a button that calls
    ``generate_audio`` with that text, and an audio component that receives
    the returned file path.
    """
    with gr.Blocks() as interface:
        # Written without leading indentation so Markdown does not treat the
        # lines as an indented code block.
        gr.Markdown(
            "# AudioGen Demo\n"
            "\n"
            "Presented in: [AudioGen: Textually Guided Audio Generation]"
            "(https://arxiv.org/abs/2209.15352)\n"
        )

        with gr.Row():
            descriptions = gr.Textbox(lines=2, label="Enter descriptions of the audio to generate")
        with gr.Row():
            generate_button = gr.Button("Generate Audio")
        with gr.Row():
            output = gr.Audio(label="Generated Audio")

        # Event wiring stays inside the Blocks context.
        generate_button.click(fn=generate_audio, inputs=descriptions, outputs=[output])

    # queue() lets long-running generation requests wait their turn rather
    # than being handled as plain synchronous requests.
    interface.queue().launch()
|
|
|
|
|
# Launch the demo only when this file is executed as a script, so importing
# the module (e.g. to reuse generate_audio) does not start a web server.
if __name__ == "__main__":
    ui_full()