Khaled21 commited on
Commit
c0a5970
·
1 Parent(s): 4969333

Updating main

Browse files
Files changed (1) hide show
  1. app.py +40 -7
app.py CHANGED
@@ -1,13 +1,46 @@
1
  import streamlit as st
2
  from transformers import pipeline
 
 
3
 
4
- model_name = "SamLowe/roberta-base-go_emotions"
5
 
6
- pipe = pipeline(task="text-classification", model=model_name)
 
 
 
 
7
 
8
- user_text = st.text_area("Explain your feelings here")
9
 
10
- if user_text:
11
- emotion = pipe(user_text)
12
- st.markdown("You are feeling like ")
13
- st.json(emotion)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import pipeline
3
+ import scipy.io.wavfile
4
+ import tempfile
5
 
 
6
 
7
+ # -----------------------------------------------------------
8
+ def load_css(file_path):
9
+ with open(file_path) as f:
10
+ css_to_load = f.read()
11
+ st.markdown(f"<style>{css_to_load}</style>", unsafe_allow_html=True)
12
 
 
13
 
14
+ # -----------------------------------------------------------
15
+ def load_model():
16
+ return pipeline("text-to-audio", model="facebook/musicgen-small")
17
+
18
+
19
+ # -----------------------------------------------------------
20
+ def main():
21
+ try:
22
+ st.set_page_config(page_title="Text to Music", page_icon="Music")
23
+ load_css("styles/styles.css")
24
+ st.title("Turn Your Text Into Music")
25
+
26
+ synthesizer = load_model()
27
+ user_text = st.text_area("What type of beat would you like to hear?")
28
+
29
+ if user_text:
30
+ music = synthesizer(user_text, forward_params={"do_sample": True})
31
+
32
+ with tempfile.NamedTemporaryFile(delete=True, suffix=".wav") as tmp:
33
+ print(f"Temporary file created at: {tmp.name}")
34
+ scipy.io.wavfile.write(
35
+ tmp.name, rate=music["sampling_rate"], data=music["audio"]
36
+ )
37
+ st.audio(tmp.name, format="audio/wav", start_time=0)
38
+
39
+ except Exception as e:
40
+ st.error(e)
41
+
42
+
43
+ # -----------------------------------------------------------
44
+
45
+ if __name__ == "__main__":
46
+ main()