| | import gradio as gr |
| | import torch |
| | import numpy as np |
| | from transformers import MusicgenForConditionalGeneration, AutoProcessor |
| | import scipy.io.wavfile |
| |
|
# Lazily-initialized cache so the model/processor are loaded exactly once per
# process instead of on every generation request (loading MusicGen from the hub
# takes many seconds and hundreds of MB; the original reloaded it per call).
_MUSICGEN_CHECKPOINT = "facebook/musicgen-small"
_MUSICGEN_CACHE = {}


def _load_musicgen():
    """Load and cache the MusicGen model, its processor, and the target device.

    Returns:
        tuple: (model, processor, device_string). Subsequent calls return the
        cached objects without touching the network or re-allocating weights.
    """
    if "model" not in _MUSICGEN_CACHE:
        model = MusicgenForConditionalGeneration.from_pretrained(_MUSICGEN_CHECKPOINT)
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        model.to(device)
        _MUSICGEN_CACHE["model"] = model
        _MUSICGEN_CACHE["processor"] = AutoProcessor.from_pretrained(_MUSICGEN_CHECKPOINT)
        _MUSICGEN_CACHE["device"] = device
    return _MUSICGEN_CACHE["model"], _MUSICGEN_CACHE["processor"], _MUSICGEN_CACHE["device"]


def generate_music(prompt, unconditional=False):
    """Generate a short music clip with MusicGen and write it to a WAV file.

    Args:
        prompt: Text description of the desired music. Ignored when
            ``unconditional`` is True.
        unconditional: When True, sample music without any text conditioning.

    Returns:
        str: Path of the WAV file written to the working directory
        ("musicgen_out.wav", overwritten on each call).
    """
    model, processor, device = _load_musicgen()

    if unconditional:
        # Sample without text conditioning; max_new_tokens bounds clip length.
        unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
        audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)
    else:
        inputs = processor(text=prompt, padding=True, return_tensors="pt")
        # guidance_scale=3 applies classifier-free guidance toward the prompt.
        audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)

    sampling_rate = model.config.audio_encoder.sampling_rate
    audio_file = "musicgen_out.wav"

    # First element of the batch; move to CPU before converting to numpy.
    audio_data = audio_values[0].cpu().numpy()

    # Output may be (channels, samples) — keep the first channel as mono.
    if audio_data.ndim > 1:
        audio_data = audio_data[0]

    # Model emits floats nominally in [-1, 1]; clip then scale to 16-bit PCM
    # so scipy writes a standard int16 WAV.
    audio_data = np.clip(audio_data, -1.0, 1.0)
    audio_data = (audio_data * 32767).astype(np.int16)

    scipy.io.wavfile.write(audio_file, sampling_rate, audio_data)

    return audio_file
| |
|
def interface(prompt, unconditional):
    """Gradio callback: delegate to generate_music and hand back the WAV path."""
    return generate_music(prompt, unconditional)
| |
|
# Assemble the Gradio UI: a prompt box plus an unconditional-mode toggle,
# a trigger button, and an audio player for the generated clip.
with gr.Blocks() as demo:
    gr.Markdown("# AI-Powered Music Generation")

    with gr.Row():
        prompt_box = gr.Textbox(label="Enter the Music Prompt")
        unconditional_toggle = gr.Checkbox(label="Generate Unconditional Music")

    run_button = gr.Button("Generate Music")
    result_audio = gr.Audio(label="Output Music")

    # Wire the button to the generation callback; show a progress indicator
    # while the (slow) model call runs.
    run_button.click(
        interface,
        inputs=[prompt_box, unconditional_toggle],
        outputs=result_audio,
        show_progress=True,
    )

demo.launch(share=True)