arjunanand13 commited on
Commit
2b6555e
·
verified ·
1 Parent(s): 4b0bb7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -118
app.py CHANGED
@@ -1,24 +1,22 @@
1
  import gradio as gr
2
  import requests
3
  from transformers import pipeline
4
- import edge_tts
5
  import tempfile
6
  import asyncio
7
  import os
8
  import json
9
 
10
- ENDPOINT_URL = "https://l8opkfvazwgxqljm.us-east-1.aws.endpoints.huggingface.cloud"
11
  hf_token = os.getenv("HF_TOKEN")
12
 
13
- asr = pipeline("automatic-speech-recognition", "facebook/wav2vec2-base-960h")
 
 
 
 
14
 
15
  INITIAL_MESSAGE = "Hi! I'm your music buddy—tell me about your mood and the type of tunes you're in the mood for today!"
16
 
17
- def speech_to_text(speech):
18
- if speech is None:
19
- return ""
20
- return asr(speech)["text"]
21
-
22
  def classify_mood(input_string):
23
  input_string = input_string.lower()
24
  mood_words = {"happy", "sad", "instrumental", "party"}
@@ -27,9 +25,9 @@ def classify_mood(input_string):
27
  return word, True
28
  return None, False
29
 
30
- def generate(prompt, history, temperature=0.1, max_new_tokens=2048):
31
  if not hf_token:
32
- return "Error: Hugging Face authentication required. Please set your HF_TOKEN."
33
 
34
  formatted_prompt = format_prompt(prompt, history)
35
 
@@ -41,8 +39,8 @@ def generate(prompt, history, temperature=0.1, max_new_tokens=2048):
41
  payload = {
42
  "model": "meta-llama/Llama-3.1-8B-Instruct",
43
  "messages": [{"role": "user", "content": formatted_prompt}],
44
- "temperature": temperature,
45
- "max_tokens": max_new_tokens,
46
  "stream": False
47
  }
48
 
@@ -55,136 +53,93 @@ def generate(prompt, history, temperature=0.1, max_new_tokens=2048):
55
 
56
  mood, is_classified = classify_mood(output)
57
  if is_classified:
58
- playlist_message = f"Playing {mood.capitalize()} playlist for you!"
59
- return playlist_message
60
  return output
61
  else:
62
  return f"Error: {response.status_code} - {response.text}"
63
 
64
  except Exception as e:
65
- return f"Error generating response: {str(e)}"
66
 
67
  def format_prompt(message, history):
68
- fixed_prompt = """
69
- You are a smart mood analyzer tasked with determining the user's mood for a music recommendation system. Your goal is to classify the user's mood into one of four categories: Happy, Sad, Instrumental, or Party.
70
- Instructions:
71
- 1. Engage in a conversation with the user to understand their mood.
72
- 2. Ask relevant questions to guide the conversation towards mood classification.
73
- 3. If the user's mood is clear, respond with a single word: "Happy", "Sad", "Instrumental", or "Party".
74
- 4. If the mood is unclear, continue the conversation with a follow-up question.
75
- 5. Limit the conversation to a maximum of 5 exchanges.
76
- 6. Do not classify the mood prematurely if it's not evident from the user's responses.
77
- 7. Focus on the user's emotional state rather than specific activities or preferences.
78
- 8. If unable to classify after 5 exchanges, respond with "Unclear" to indicate the need for more information.
79
- Remember: Your primary goal is mood classification. Stay on topic and guide the conversation towards understanding the user's emotional state.
80
- """
81
- prompt = f"{fixed_prompt}\n"
82
 
83
- for i, (user_prompt, bot_response) in enumerate(history):
84
- prompt += f"User: {user_prompt}\nAssistant: {bot_response}\n"
85
- if i == 3:
86
- prompt += "Note: This is the last exchange. Classify the mood if possible or respond with 'Unclear'.\n"
87
 
88
  prompt += f"User: {message}\nAssistant:"
89
  return prompt
90
 
91
- async def text_to_speech(text):
 
 
 
 
 
 
 
 
 
 
 
92
  try:
93
- communicate = edge_tts.Communicate(text)
94
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
95
- tmp_path = tmp_file.name
96
- await communicate.save(tmp_path)
97
- return tmp_path
98
  except Exception as e:
99
- print(f"TTS Error: {e}")
100
- return None
101
-
102
- def process_input(input_text, history):
103
- if not input_text:
104
- return history, history, ""
105
- response = generate(input_text, history)
106
- history.append((input_text, response))
107
- return history, history, ""
108
-
109
- async def generate_audio(history):
110
- if history and len(history) > 0:
111
- last_response = history[-1][1]
112
- audio_path = await text_to_speech(last_response)
113
- return audio_path
114
- return None
115
-
116
- async def init_chat():
117
- history = [("", INITIAL_MESSAGE)]
118
- audio_path = await text_to_speech(INITIAL_MESSAGE)
119
- return history, history, audio_path
120
-
121
- def handle_voice_upload(audio_file):
122
- if audio_file is None:
123
- return ""
124
- return speech_to_text(audio_file)
125
-
126
- with gr.Blocks() as demo:
127
- gr.Markdown("# Mood-Based Music Recommender with Voice Chat")
128
 
129
- chatbot = gr.Chatbot()
130
 
131
  with gr.Row():
132
  msg = gr.Textbox(
133
  placeholder="Type your message here...",
134
- label="Text Input",
135
  scale=4
136
  )
137
- submit = gr.Button("Send", scale=1)
138
 
139
- with gr.Row():
140
- voice_input = gr.File(
141
- label="Upload Voice Recording (or record using your device)",
142
- file_types=[".wav", ".mp3", ".m4a", ".ogg"]
 
143
  )
 
144
 
145
- audio_output = gr.Audio(label="AI Response", autoplay=True)
146
-
147
- state = gr.State([])
148
-
149
- demo.load(init_chat, outputs=[state, chatbot, audio_output])
150
-
151
- def submit_and_generate_audio(input_text, history):
152
- new_state, new_chatbot, empty_msg = process_input(input_text, history)
153
- return new_state, new_chatbot, empty_msg
154
-
155
- msg.submit(
156
- submit_and_generate_audio,
157
- inputs=[msg, state],
158
- outputs=[state, chatbot, msg]
159
- ).then(
160
- generate_audio,
161
- inputs=[state],
162
- outputs=[audio_output]
163
- )
164
 
165
- submit.click(
166
- submit_and_generate_audio,
167
- inputs=[msg, state],
168
- outputs=[state, chatbot, msg]
169
- ).then(
170
- generate_audio,
171
- inputs=[state],
172
- outputs=[audio_output]
173
- )
174
-
175
- voice_input.upload(
176
- handle_voice_upload,
177
- inputs=[voice_input],
178
- outputs=[msg]
179
- ).then(
180
- submit_and_generate_audio,
181
- inputs=[msg, state],
182
- outputs=[state, chatbot, msg]
183
- ).then(
184
- generate_audio,
185
- inputs=[state],
186
- outputs=[audio_output]
187
- )
188
 
189
  if __name__ == "__main__":
190
- demo.launch(share=True)
 
1
  import gradio as gr
2
  import requests
3
  from transformers import pipeline
 
4
  import tempfile
5
  import asyncio
6
  import os
7
  import json
8
 
9
+ ENDPOINT_URL = "https://xzup8268xrmmxcma.us-east-1.aws.endpoints.huggingface.cloud"
10
  hf_token = os.getenv("HF_TOKEN")
11
 
12
+ try:
13
+ asr = pipeline("automatic-speech-recognition", "facebook/wav2vec2-base-960h")
14
+ except:
15
+ asr = None
16
+ print("ASR model failed to load, voice features disabled")
17
 
18
  INITIAL_MESSAGE = "Hi! I'm your music buddy—tell me about your mood and the type of tunes you're in the mood for today!"
19
 
 
 
 
 
 
20
  def classify_mood(input_string):
21
  input_string = input_string.lower()
22
  mood_words = {"happy", "sad", "instrumental", "party"}
 
25
  return word, True
26
  return None, False
27
 
28
+ def generate(prompt, history):
29
  if not hf_token:
30
+ return "Error: Please set your HF_TOKEN environment variable."
31
 
32
  formatted_prompt = format_prompt(prompt, history)
33
 
 
39
  payload = {
40
  "model": "meta-llama/Llama-3.1-8B-Instruct",
41
  "messages": [{"role": "user", "content": formatted_prompt}],
42
+ "temperature": 0.1,
43
+ "max_tokens": 512,
44
  "stream": False
45
  }
46
 
 
53
 
54
  mood, is_classified = classify_mood(output)
55
  if is_classified:
56
+ return f"🎵 Playing {mood.capitalize()} playlist for you! 🎵"
 
57
  return output
58
  else:
59
  return f"Error: {response.status_code} - {response.text}"
60
 
61
  except Exception as e:
62
+ return f"Error: {str(e)}"
63
 
64
  def format_prompt(message, history):
65
+ prompt = """You are a mood analyzer for music recommendations. Classify user mood as: Happy, Sad, Instrumental, or Party.
66
+
67
+ Instructions:
68
+ 1. Chat with the user to understand their mood
69
+ 2. When clear, respond with ONLY one word: Happy, Sad, Instrumental, or Party
70
+ 3. If unclear, ask a follow-up question
71
+ 4. Maximum 5 exchanges
72
+
73
+ """
 
 
 
 
 
74
 
75
+ for user_msg, bot_msg in history:
76
+ if user_msg.strip():
77
+ prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
 
78
 
79
  prompt += f"User: {message}\nAssistant:"
80
  return prompt
81
 
82
+ def chat_interface(message, history):
83
+ if not message.strip():
84
+ return history, ""
85
+
86
+ response = generate(message, history)
87
+ history.append([message, response])
88
+ return history, ""
89
+
90
+ def speech_to_text_simple(audio_file):
91
+ if not asr or not audio_file:
92
+ return "Voice recognition not available. Please type your message."
93
+
94
  try:
95
+ result = asr(audio_file)
96
+ return result["text"]
 
 
 
97
  except Exception as e:
98
+ return f"Voice processing error: {str(e)}"
99
+
100
+ css = """
101
+ .gradio-container {
102
+ max-width: 800px !important;
103
+ margin: auto !important;
104
+ }
105
+ """
106
+
107
+ with gr.Blocks(css=css, title="Music Mood Analyzer") as demo:
108
+ gr.Markdown("# 🎵 Music Mood Analyzer")
109
+ gr.Markdown("Tell me about your mood and I'll recommend the perfect playlist!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
+ chatbot = gr.Chatbot(height=400, label="Chat")
112
 
113
  with gr.Row():
114
  msg = gr.Textbox(
115
  placeholder="Type your message here...",
116
+ label="Message",
117
  scale=4
118
  )
119
+ send_btn = gr.Button("Send", scale=1, variant="primary")
120
 
121
+ if asr:
122
+ gr.Markdown("### 🎤 Voice Input (Optional)")
123
+ audio_input = gr.Audio(
124
+ label="Record your voice",
125
+ type="filepath"
126
  )
127
+ transcribe_btn = gr.Button("Convert Speech to Text")
128
 
129
+ transcribe_btn.click(
130
+ speech_to_text_simple,
131
+ inputs=[audio_input],
132
+ outputs=[msg]
133
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
+ def respond(message, history):
136
+ history, empty = chat_interface(message, history)
137
+ return history, empty
138
+
139
+ msg.submit(respond, [msg, chatbot], [chatbot, msg])
140
+ send_btn.click(respond, [msg, chatbot], [chatbot, msg])
141
+
142
+ demo.load(lambda: [[None, INITIAL_MESSAGE]], None, chatbot)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  if __name__ == "__main__":
145
+ demo.launch(share=True, show_error=True)