SmallO commited on
Commit
cba6ba1
·
1 Parent(s): 0c88ebf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -23
app.py CHANGED
@@ -1,39 +1,50 @@
1
- # -*- coding: utf-8 -*-
 
 
 
 
2
 
3
- #!pip install torch
4
 
5
- #!pip install diffusers transformers accelerate torch
 
 
6
 
7
- #!pip install streamlit
 
8
 
9
- from diffusers import StableDiffusionPipeline
10
- import torch
11
- import streamlit as st
 
 
 
 
12
 
13
- def pic_mo(prom):
14
- model_id = "runwayml/stable-diffusion-v1-5"
15
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32) # Change to float32
16
- # No need ("cuda") if running on CPU because dont't have money
17
 
18
- prompt = str(prom)
 
19
 
20
- image = pipe(prompt).images[0]
21
-
22
- return image
23
 
24
 
25
  def main():
26
- st.title("Text to Image with runwayml/stable-diffusion-v1-5")
27
 
28
  # Input text prompt
29
- title = st.text_input('Write a prompt (จะใช้เวลาค่อนข้างมากในการสร้างภาพเนื่องจากใช้ CPU ในการรันโมเดล)', "")
30
 
31
- if st.button('Generate Image'):
32
- # Call the pic_mo function to generate an image
33
- generated_image = pic_mo(title)
 
34
 
35
- # Display the generated image
36
- st.image(generated_image, caption='Generated Image', use_column_width=True)
 
 
37
 
38
- if __name__ == '__main__':
39
  main()
 
1
+ from transformers import AutoProcessor, MusicgenForConditionalGeneration
2
+ from IPython.display import Audio
3
+ import scipy
4
+ import torch
5
+ import streamlit as st
6
 
 
7
 
8
+ def mu_gen(prompt):
9
+ processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
10
+ model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
11
 
12
+ device = torch.device("cpu")
13
+ model.to(device)
14
 
15
+ inputs = processor(
16
+ text = [str(prompt)], # This line is correct
17
+ padding=True,
18
+ return_tensors="pt",
19
+ )
20
+
21
+ inputs = {key: value.to(device) for key, value in inputs.items()}
22
 
23
+ # Generate audio on CPU
24
+ audio_values = model.generate(**inputs, max_new_tokens=256)
25
+ sampling_rate = model.config.audio_encoder.sampling_rate
 
26
 
27
+ # Create an Audio object from the generated audio
28
+ result = Audio(audio_values[0].numpy(), rate=sampling_rate)
29
 
30
+ return result
 
 
31
 
32
 
33
  def main():
34
+ st.title("Text to Music Generator")
35
 
36
  # Input text prompt
37
+ prompt = st.text_input("Enter a text prompt", "")
38
 
39
+ if st.button("Generate Music"):
40
+ if prompt:
41
+ # Call the mu_gen function to generate music
42
+ generated_music = mu_gen(prompt)
43
 
44
+ # Display the generated audio
45
+ st.audio(generated_music, format="audio/wav")
46
+ else:
47
+ st.warning("Please enter a text prompt.")
48
 
49
+ if __name__ == "__main__":
50
  main()