|
|
import io |
|
|
from transformers import AutoProcessor, MusicgenForConditionalGeneration |
|
|
from IPython.display import Audio |
|
|
import torch |
|
|
import streamlit as st |
|
|
import wave |
|
|
|
|
|
def mu_gen(prompt):
    """Generate a short music clip from a text prompt with MusicGen.

    Args:
        prompt: Free-text description of the desired music; coerced to str.

    Returns:
        bytes: A complete WAV file (mono, 16-bit PCM) containing the
        generated audio, suitable for ``st.audio``.
    """
    # NOTE(review): the processor and model are re-downloaded/re-loaded on
    # every call — consider caching at the app level (st.cache_resource).
    processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
    model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

    device = torch.device("cpu")
    model.to(device)

    inputs = processor(
        text=[str(prompt)],
        padding=True,
        return_tensors="pt",
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        audio_values = model.generate(**inputs, max_new_tokens=256)

    sampling_rate = model.config.audio_encoder.sampling_rate

    # BUG FIX: generate() returns float waveforms in [-1, 1].  The original
    # code wrote the raw float32 bytes into a WAV declared as 16-bit PCM
    # (setsampwidth(2)), which plays back as noise at the wrong duration.
    # Clamp and convert to int16 PCM before writing the frames.
    pcm16 = (
        audio_values[0]
        .squeeze()
        .clamp(-1.0, 1.0)
        .mul(32767.0)
        .to(torch.int16)
        .cpu()
        .numpy()
    )

    with io.BytesIO() as wav_file:
        with wave.open(wav_file, 'wb') as wf:
            wf.setnchannels(1)          # mono
            wf.setsampwidth(2)          # 16-bit samples
            wf.setframerate(sampling_rate)
            wf.writeframes(pcm16.tobytes())

        wav_bytes = wav_file.getvalue()

    return wav_bytes
|
|
|
|
|
def main():
    """Streamlit entry point: collect a text prompt and play generated music."""
    st.title("Text to Music")

    title = st.text_input('Write a prompt (จะใช้เวลาค่อนข้างมากในการสร้างเนื่องจากใช้ CPU ในการรันโมเดล)', "")

    if st.button('Generate Music'):
        # Robustness: generation takes minutes on CPU, so refuse to run the
        # model on an empty/whitespace-only prompt instead of wasting the work.
        if not title.strip():
            st.warning('Please enter a prompt first.')
        else:
            generated_music = mu_gen(title)
            st.audio(generated_music, format='audio/wav', start_time=0)
|
|
|
|
|
if __name__ == '__main__':
    # Launch the app only when executed directly (e.g. `streamlit run <file>`),
    # not when this module is imported.
    main()
|
|
|