File size: 1,272 Bytes
cba6ba1
 
 
79083e5
cba6ba1
105d116
 
79083e5
105d116
 
79083e5
105d116
 
 
 
 
79083e5
105d116
79083e5
105d116
 
 
79083e5
105d116
14fb55e
79083e5
cba6ba1
79083e5
 
cba6ba1
79083e5
cba6ba1
 
 
105d116
79083e5
cba6ba1
105d116
cba6ba1
 
79083e5
cba6ba1
105d116
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
import streamlit as st

@st.cache_resource
def _load_musicgen(model_name="facebook/musicgen-small", device="cpu"):
    """Load the MusicGen processor and model once and reuse them.

    Streamlit re-runs the whole script on every widget interaction, so an
    uncached load would rebuild the model on each button click.
    ``st.cache_resource`` keeps one shared instance per
    ``(model_name, device)`` across reruns.

    Returns:
        A ``(processor, model)`` tuple with the model already moved to
        ``device``.
    """
    processor = AutoProcessor.from_pretrained(model_name)
    model = MusicgenForConditionalGeneration.from_pretrained(model_name)
    model.to(device)
    return processor, model


def mu_gen(prompt, max_new_tokens=256, model_name="facebook/musicgen-small"):
    """Generate audio from a text prompt with MusicGen on the CPU.

    Args:
        prompt: Text description of the desired music; passed through
            ``str()`` before tokenization.
        max_new_tokens: Upper bound on the number of audio tokens generated.
            Larger values produce longer clips and take longer. Defaults to
            256, the value this function previously hard-coded.
        model_name: Hugging Face model id to load. Defaults to
            ``"facebook/musicgen-small"``.

    Returns:
        ``(audio_values, sampling_rate)``. ``audio_values`` is the tensor
        returned by ``model.generate``; it is presumably shaped
        (batch, channels, samples), so verify before relying on the layout.
        ``sampling_rate`` is read from the model's audio-encoder config.
    """
    device = torch.device("cpu")
    # Loading is cached, so only the first call pays the download/init cost.
    processor, model = _load_musicgen(model_name, "cpu")

    inputs = processor(
        text=[str(prompt)],
        padding=True,
        return_tensors="pt",
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Generate audio on CPU
    audio_values = model.generate(**inputs, max_new_tokens=max_new_tokens)
    sampling_rate = model.config.audio_encoder.sampling_rate

    return audio_values, sampling_rate

def main():
    """Render the text-to-music Streamlit page.

    The page shows a prompt box and a "Generate Music" button. On click it
    generates a clip with ``mu_gen`` and plays it with ``st.audio``. It shows
    a warning instead when the prompt is empty or whitespace-only.
    """
    st.title("Text to Music Generator")

    # Input text prompt
    prompt = st.text_input("Enter a text prompt", "")

    if st.button("Generate Music"):
        # Strip so a whitespace-only prompt is rejected like an empty one
        # instead of being sent to the model.
        prompt = prompt.strip()
        if prompt:
            # CPU generation can take a while; give the user visible feedback.
            with st.spinner("Generating music..."):
                generated_music, sampling_rate = mu_gen(prompt)

            # Play the first item of the batch. .cpu() keeps .numpy() valid
            # even if generation is ever moved off the CPU.
            st.audio(
                generated_music[0].cpu().numpy(),
                format="audio/wav",
                sample_rate=sampling_rate,
            )
        else:
            st.warning("Please enter a text prompt.")

# Launch the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()