SmallO commited on
Commit
ec4522b
·
1 Parent(s): 9d4aea3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -23
app.py CHANGED
@@ -1,47 +1,51 @@
 
 
1
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
2
  from IPython.display import Audio
3
- import scipy
4
  import torch
5
  import streamlit as st
6
 
7
-
8
  def mu_gen(prompt):
9
- processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
10
- model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
 
 
 
11
 
12
- device = torch.device("cpu")
13
- model.to(device)
 
 
 
14
 
15
- inputs = processor(
16
- text = [str(prompt)], # This line is correct
17
- padding=True,
18
- return_tensors="pt",
19
- )
20
-
21
- inputs = {key: value.to(device) for key, value in inputs.items()}
22
 
23
  # Generate audio on CPU
24
- audio_values = model.generate(**inputs, max_new_tokens=256)
25
- sampling_rate = model.config.audio_encoder.sampling_rate
26
 
27
- # Create an Audio object from the generated audio
28
- result = Audio(audio_values[0].numpy(), rate=sampling_rate)
29
 
30
- return result
 
 
 
31
 
 
32
 
33
  def main():
34
- st.title("Text to music")
35
 
36
  # Input text prompt
37
  title = st.text_input('Write a prompt (จะใช้เวลาค่อนข้างมากในการสร้างเนื่องจากใช้ CPU ในการรันโมเดล)', "")
38
 
39
- if st.button('Generate Image'):
40
  # Call the mu_gen function to generate music using the 'title' prompt
41
- generated_music = mu_gen(title) # Replace 'prompt' with 'title'
42
 
43
- # Display the generated audio
44
- st.audio(generated_music, format='audio/mpeg', start_time=0)
45
 
46
  if __name__ == '__main__':
47
  main()
 
1
+ import scipy.io.wavfile as wav
2
+ import io
3
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
4
  from IPython.display import Audio
 
5
  import torch
6
  import streamlit as st
7
 
 
8
  def mu_gen(prompt):
9
+ processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
10
+ model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
11
+
12
+ device = torch.device("cpu")
13
+ model.to(device)
14
 
15
+ inputs = processor(
16
+ text=[str(prompt)], # This line is correct
17
+ padding=True,
18
+ return_tensors="pt",
19
+ )
20
 
21
+ inputs = {key: value.to(device) for key, value in inputs.items()}
 
 
 
 
 
 
22
 
23
  # Generate audio on CPU
24
+ audio_values = model.generate(**inputs, max_new_tokens=256)
25
+ sampling_rate = model.config.audio_encoder.sampling_rate
26
 
27
+ # Convert audio data to WAV format
28
+ wav_data = audio_values[0].numpy().tobytes()
29
 
30
+ # Create an in-memory WAV file
31
+ with io.BytesIO() as wav_file:
32
+ wav.write(wav_file, sampling_rate, wav_data)
33
+ wav_bytes = wav_file.getvalue()
34
 
35
+ return wav_bytes # Return the WAV audio data as bytes
36
 
37
  def main():
38
+ st.title("Text to Music")
39
 
40
  # Input text prompt
41
  title = st.text_input('Write a prompt (จะใช้เวลาค่อนข้างมากในการสร้างเนื่องจากใช้ CPU ในการรันโมเดล)', "")
42
 
43
+ if st.button('Generate Music'):
44
  # Call the mu_gen function to generate music using the 'title' prompt
45
+ generated_music = mu_gen(title)
46
 
47
+ # Display the generated audio in WAV format
48
+ st.audio(generated_music, format='audio/wav', start_time=0)
49
 
50
  if __name__ == '__main__':
51
  main()