| import gradio as gr |
| import torchaudio |
| from audiocraft.models import MAGNeT |
| from audiocraft. data. audio import audio_write |
|
|
# Identifier of the pretrained MAGNeT checkpoint to load.
MODEL_ID = 'facebook/magnet-small-10secs'

# Loaded once at import time; `infer` reuses this single instance on every call.
model = MAGNeT.get_pretrained(MODEL_ID)
|
|
|
|
def infer(description: str) -> str:
    """Generate audio for a text prompt with the module-level MAGNeT ``model``.

    Args:
        description: Text prompt supplied by the UI. When it is ``None``,
            empty, or only whitespace, the two built-in demo prompts are
            used instead.

    Returns:
        The literal string ``"done"`` once every generated clip has been
        handed to ``audio_write``.
    """
    prompt = (description or "").strip()

    # The caller's prompt used to be ignored in favour of a hard-coded list;
    # that list is now only the fallback for blank input.
    descriptions = [prompt] if prompt else ['disco beat', 'energetic EDM']

    wav = model.generate(descriptions)

    for idx, one_wav in enumerate(wav):
        print(idx)
        # The batch index is used as the output file stem, so repeated calls
        # reuse the same stems (presumably overwriting earlier output --
        # confirm against audio_write's behaviour).
        audio_write(
            f'{idx}',
            one_wav.cpu(),
            model.sample_rate,
            strategy="loudness",
            loudness_compressor=True,
        )

    return "done"
|
|
# Wire `infer` to a one-textbox-in, one-textbox-out Gradio UI. The input box
# is pre-filled with "gogo"; the output box shows the string `infer` returns.
demo = gr.Interface(
    fn=infer,
    inputs=gr.Textbox(value="gogo"),
    outputs=gr.Textbox(),
)

# Start the web server for the UI.
demo.launch()