Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import torch | |
| import transformers | |
| from packaging.version import parse | |
| import sys | |
| import io | |
| import importlib.metadata as importlib_metadata | |
| import soundfile as sf | |
| import importlib.metadata as importlib_metadata | |
# Extra keyword arguments forwarded to ``from_pretrained`` when loading the model.
# From transformers 4.40.0 onward an explicit attention backend is requested
# ("eager"); older releases get no extra arguments.
loading_kwargs = (
    {"attn_implementation": "eager"}
    if parse(importlib_metadata.version("transformers")) >= parse("4.40.0")
    else {}
)
# Hugging Face Hub checkpoint used for both the model and its processor.
_MUSICGEN_CHECKPOINT = "facebook/musicgen-small"


@st.cache_resource
def _load_musicgen():
    """Load the MusicGen model and processor once and reuse them across reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    loading inside ``generate`` would reload the checkpoint on every click.
    ``st.cache_resource`` keeps one shared instance instead.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # NOTE(review): torchscript=True / return_dict=False look like leftovers
    # from a TorchScript/OpenVINO export workflow and are probably not needed
    # for plain generate() -- confirm before removing.
    model = transformers.MusicgenForConditionalGeneration.from_pretrained(
        _MUSICGEN_CHECKPOINT,
        torchscript=True,
        return_dict=False,
        **loading_kwargs,
    )
    processor = transformers.AutoProcessor.from_pretrained(_MUSICGEN_CHECKPOINT)
    return model, processor


def generate(prompt, sample_length=8):
    """Generate a music clip from a text prompt and return it as WAV bytes.

    Args:
        prompt: Text description of the music to generate.
        sample_length: Requested clip length in seconds (default 8). It is
            converted to a token budget via the audio encoder's frame rate.

    Returns:
        An ``io.BytesIO`` rewound to offset 0 that holds a 16-bit PCM WAV file.

    Raises:
        ValueError: If ``sample_length`` is not positive.
    """
    if sample_length <= 0:
        raise ValueError(f"sample_length must be positive, got {sample_length!r}")

    model, processor = _load_musicgen()
    audio_config = model.config.audio_encoder
    # frame_rate = codec tokens per second of audio. NOTE(review): the extra 3
    # tokens presumably cover MusicGen's codebook delay pattern -- confirm.
    n_tokens = int(sample_length * audio_config.frame_rate) + 3
    sampling_rate = audio_config.sampling_rate

    inputs = processor(text=[prompt], padding=True, return_tensors="pt")
    audio_values = model.generate(
        **inputs, do_sample=True, guidance_scale=3, max_new_tokens=n_tokens
    )

    # Take the single batch item. This assumes a mono (1, num_samples) output,
    # which squeeze() flattens to 1-D -- confirm for multi-channel checkpoints.
    waveform = audio_values[0].cpu().squeeze().numpy()

    # Scale nominally [-1, 1] floats to 16-bit PCM. Clip before casting: a
    # sample of exactly 1.0 maps to 32768, and astype(np.int16) on an
    # out-of-range value overflows (wraps), producing audible clicks.
    int16_info = np.iinfo(np.int16)
    pcm16 = np.clip(waveform * 2**15, int16_info.min, int16_info.max).astype(np.int16)

    audio_buffer = io.BytesIO()
    sf.write(audio_buffer, pcm16, sampling_rate, format='WAV')
    audio_buffer.seek(0)  # rewind so st.audio reads from the start
    return audio_buffer
# Radio option meaning "no example selected".
_NO_EXAMPLE = "None"


def _resolve_prompt(text_prompt, selected_example):
    """Return the prompt to generate from, or "" if the user supplied none.

    Free text takes priority over the selected example. Whitespace-only text
    is treated as empty so it cannot shadow a selected example.
    """
    typed = (text_prompt or "").strip()
    if typed:
        return typed
    if selected_example and selected_example != _NO_EXAMPLE:
        return selected_example
    return ""


st.title("Music Generator")
st.subheader("Select an example or write a text prompt")
text_prompt = st.text_input("Text Prompt", "")

examples = [
    "80s pop track with bassy drums and synth",
    "Earthy tones, environmentally conscious, ukulele-infused, harmonic, breezy, easygoing, organic instrumentation, gentle grooves",
    "90s rock song with loud guitars and heavy drums",
    "Heartful EDM with beautiful synths and chords",
    _NO_EXAMPLE,
]
st.subheader("Examples")
selected_example = st.radio("Select an example", examples)

# Resolved once so the generation path and the debug panel always agree.
prompt = _resolve_prompt(text_prompt, selected_example)

if st.button("Generate Audio"):
    if prompt:
        with st.spinner("Generating audio..."):
            audio_output = generate(prompt)
        st.audio(audio_output, format='audio/wav')
    else:
        st.warning("Please select or enter a text prompt.")

if st.checkbox("Show debug info"):
    # With no prompt, show the raw radio value, which is the "None" sentinel.
    st.write("Prompt:", prompt or selected_example)