arjunanand13 committed on
Commit
e3ee8f6
·
verified ·
1 Parent(s): 962d7db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -59
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import requests
3
  from transformers import pipeline
 
4
  import tempfile
5
  import asyncio
6
  import os
@@ -9,14 +10,15 @@ import json
9
  ENDPOINT_URL = "https://xzup8268xrmmxcma.us-east-1.aws.endpoints.huggingface.cloud"
10
  hf_token = os.getenv("HF_TOKEN")
11
 
12
- try:
13
- asr = pipeline("automatic-speech-recognition", "facebook/wav2vec2-base-960h")
14
- except:
15
- asr = None
16
- print("ASR model failed to load, voice features disabled")
17
 
18
  INITIAL_MESSAGE = "Hi! I'm your music buddy—tell me about your mood and the type of tunes you're in the mood for today!"
19
 
 
 
 
 
 
20
  def classify_mood(input_string):
21
  input_string = input_string.lower()
22
  mood_words = {"happy", "sad", "instrumental", "party"}
@@ -25,9 +27,9 @@ def classify_mood(input_string):
25
  return word, True
26
  return None, False
27
 
28
- def generate(prompt, history):
29
  if not hf_token:
30
- return "Error: Please set your HF_TOKEN environment variable."
31
 
32
  formatted_prompt = format_prompt(prompt, history)
33
 
@@ -39,8 +41,8 @@ def generate(prompt, history):
39
  payload = {
40
  "model": "meta-llama/Llama-3.1-8B-Instruct",
41
  "messages": [{"role": "user", "content": formatted_prompt}],
42
- "temperature": 0.1,
43
- "max_tokens": 512,
44
  "stream": False
45
  }
46
 
@@ -53,13 +55,14 @@ def generate(prompt, history):
53
 
54
  mood, is_classified = classify_mood(output)
55
  if is_classified:
56
- return f"🎵 Playing {mood.capitalize()} playlist for you! 🎵"
 
57
  return output
58
  else:
59
  return f"Error: {response.status_code} - {response.text}"
60
 
61
  except Exception as e:
62
- return f"Error: {str(e)}"
63
 
64
  def format_prompt(message, history):
65
  fixed_prompt = """
@@ -85,67 +88,104 @@ def format_prompt(message, history):
85
  prompt += f"User: {message}\nAssistant:"
86
  return prompt
87
 
88
- def chat_interface(message, history):
89
- if not message.strip():
90
- return history, ""
91
-
92
- response = generate(message, history)
93
- history.append([message, response])
94
- return history, ""
95
-
96
- def speech_to_text_simple(audio_file):
97
- if not asr or not audio_file:
98
- return "Voice recognition not available. Please type your message."
99
-
100
  try:
101
- result = asr(audio_file)
102
- return result["text"]
 
 
 
103
  except Exception as e:
104
- return f"Voice processing error: {str(e)}"
105
-
106
- css = """
107
- .gradio-container {
108
- max-width: 800px !important;
109
- margin: auto !important;
110
- }
111
- """
112
-
113
- with gr.Blocks(css=css, title="Music Mood Analyzer") as demo:
114
- gr.Markdown("# 🎵 Music Mood Analyzer")
115
- gr.Markdown("Tell me about your mood and I'll recommend the perfect playlist!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
- chatbot = gr.Chatbot(height=400, label="Chat")
118
 
119
  with gr.Row():
120
  msg = gr.Textbox(
121
  placeholder="Type your message here...",
122
- label="Message",
123
  scale=4
124
  )
125
- send_btn = gr.Button("Send", scale=1, variant="primary")
126
 
127
- if asr:
128
- gr.Markdown("### 🎤 Voice Input (Optional)")
129
- audio_input = gr.Audio(
130
- label="Record your voice",
131
  type="filepath"
132
  )
133
- transcribe_btn = gr.Button("Convert Speech to Text")
134
 
135
- transcribe_btn.click(
136
- speech_to_text_simple,
137
- inputs=[audio_input],
138
- outputs=[msg]
139
- )
140
-
141
- def respond(message, history):
142
- history, empty = chat_interface(message, history)
143
- return history, empty
144
-
145
- msg.submit(respond, [msg, chatbot], [chatbot, msg])
146
- send_btn.click(respond, [msg, chatbot], [chatbot, msg])
 
 
 
 
 
 
 
147
 
148
- demo.load(lambda: [[None, INITIAL_MESSAGE]], None, chatbot)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  if __name__ == "__main__":
151
- demo.launch(share=True, show_error=True)
 
1
  import gradio as gr
2
  import requests
3
  from transformers import pipeline
4
+ import edge_tts
5
  import tempfile
6
  import asyncio
7
  import os
 
10
  ENDPOINT_URL = "https://xzup8268xrmmxcma.us-east-1.aws.endpoints.huggingface.cloud"
11
  hf_token = os.getenv("HF_TOKEN")
12
 
13
+ asr = pipeline("automatic-speech-recognition", "facebook/wav2vec2-base-960h")
 
 
 
 
14
 
15
  INITIAL_MESSAGE = "Hi! I'm your music buddy—tell me about your mood and the type of tunes you're in the mood for today!"
16
 
17
def speech_to_text(speech):
    """Transcribe an audio input with the module-level ``asr`` pipeline.

    Args:
        speech: Audio input handed to ``asr``; ``None`` means no audio
            was supplied.

    Returns:
        The ``"text"`` field of the ASR result, or ``""`` when ``speech``
        is ``None``.
    """
    if speech is not None:
        transcription = asr(speech)
        return transcription["text"]
    return ""
21
+
22
  def classify_mood(input_string):
23
  input_string = input_string.lower()
24
  mood_words = {"happy", "sad", "instrumental", "party"}
 
27
  return word, True
28
  return None, False
29
 
30
+ def generate(prompt, history, temperature=0.1, max_new_tokens=2048):
31
  if not hf_token:
32
+ return "Error: Hugging Face authentication required. Please set your HF_TOKEN."
33
 
34
  formatted_prompt = format_prompt(prompt, history)
35
 
 
41
  payload = {
42
  "model": "meta-llama/Llama-3.1-8B-Instruct",
43
  "messages": [{"role": "user", "content": formatted_prompt}],
44
+ "temperature": temperature,
45
+ "max_tokens": max_new_tokens,
46
  "stream": False
47
  }
48
 
 
55
 
56
  mood, is_classified = classify_mood(output)
57
  if is_classified:
58
+ playlist_message = f"Playing {mood.capitalize()} playlist for you!"
59
+ return playlist_message
60
  return output
61
  else:
62
  return f"Error: {response.status_code} - {response.text}"
63
 
64
  except Exception as e:
65
+ return f"Error generating response: {str(e)}"
66
 
67
  def format_prompt(message, history):
68
  fixed_prompt = """
 
88
  prompt += f"User: {message}\nAssistant:"
89
  return prompt
90
 
91
async def text_to_speech(text):
    """Synthesize ``text`` into a temporary audio file using edge-tts.

    Args:
        text: The text to speak.

    Returns:
        Path of the temporary audio file, or ``None`` when there is
        nothing to speak or synthesis fails (the error is printed).
    """
    # Nothing to say: skip the synthesis round-trip entirely.
    if not isinstance(text, str) or not text.strip():
        return None

    tmp_path = None
    try:
        communicate = edge_tts.Communicate(text)
        # edge-tts writes MP3 data, so the file suffix must say so;
        # a ".wav" name on MP3 bytes mislabels the media type.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            tmp_path = tmp_file.name
        # Save only after our own handle is closed so the file is not
        # held open twice (this matters on Windows).
        await communicate.save(tmp_path)
        return tmp_path
    except Exception as e:
        print(f"TTS Error: {e}")
        # Do not leave an empty/partial temp file behind on failure.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        return None
101
+
102
def process_input(input_text, history):
    """Run one chat turn and return the updated conversation.

    Args:
        input_text: The user's message. ``None``, empty, or
            whitespace-only input is treated as "nothing to send".
        history: List of ``(user_message, bot_response)`` pairs so far;
            ``None`` is treated as an empty conversation.

    Returns:
        A ``(state, chatbot, textbox)`` triple: the conversation twice
        (once for the state, once for the chat display) and ``""`` to
        clear the input box.
    """
    turns = [] if history is None else history

    # Blank input must not trigger an LLM call.
    if not input_text or not input_text.strip():
        return turns, turns, ""

    response = generate(input_text, turns)
    # Build a new list instead of appending so the caller's list is
    # not modified behind its back.
    updated = list(turns) + [(input_text, response)]
    return updated, updated, ""
108
+
109
async def generate_audio(history):
    """Synthesize speech for the most recent bot response.

    Args:
        history: List of ``(user_message, bot_response)`` pairs.

    Returns:
        Whatever ``text_to_speech`` returns for the last response, or
        ``None`` when there is no history or the last response is empty.
    """
    if not history:
        return None

    last_response = history[-1][1]
    # An empty or missing reply has nothing to speak; skip the TTS call.
    if not last_response:
        return None

    return await text_to_speech(last_response)
115
+
116
async def init_chat():
    """Build the opening conversation and its spoken greeting.

    Returns:
        A ``(state, chatbot, audio)`` triple: the one-turn history
        containing ``INITIAL_MESSAGE`` (twice) and the result of
        ``text_to_speech`` for that message.
    """
    greeting_audio = await text_to_speech(INITIAL_MESSAGE)
    opening_turns = [("", INITIAL_MESSAGE)]
    return opening_turns, opening_turns, greeting_audio
120
+
121
def handle_voice_upload(audio_file):
    """Transcribe a supplied audio file via ``speech_to_text``.

    Args:
        audio_file: The audio input, or ``None`` when nothing was given.

    Returns:
        The transcription, or ``""`` when ``audio_file`` is ``None``.
    """
    return "" if audio_file is None else speech_to_text(audio_file)
125
+
126
# UI layout and event wiring for the chat app.
with gr.Blocks() as demo:
    gr.Markdown("# Mood-Based Music Recommender with Continuous Voice Chat")

    chatbot = gr.Chatbot()

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type your message here...",
            label="Text Input",
            scale=4
        )
        submit = gr.Button("Send", scale=1)

    with gr.Row():
        voice_input = gr.Audio(
            label="🎤 Record your voice or upload audio file",
            sources=["microphone", "upload"],
            type="filepath"
        )

    audio_output = gr.Audio(label="AI Response", autoplay=True)

    state = gr.State([])

    # Seed the conversation (and spoken greeting) when the page loads.
    demo.load(init_chat, outputs=[state, chatbot, audio_output])

    def submit_and_generate_audio(input_text, history):
        """Run one chat turn; audio is produced by the chained step."""
        new_state, new_chatbot, empty_msg = process_input(input_text, history)
        return new_state, new_chatbot, empty_msg

    # Shared input/output wiring for the two pipeline stages.
    chat_io = dict(inputs=[msg, state], outputs=[state, chatbot, msg])
    speak_io = dict(inputs=[state], outputs=[audio_output])

    # Text path: Enter in the textbox and the Send button behave alike.
    for text_trigger in (msg.submit, submit.click):
        text_trigger(submit_and_generate_audio, **chat_io).then(
            generate_audio, **speak_io
        )

    # Voice path: the widget offers both "upload" and "microphone", so
    # both the upload event and the end of a recording must start the
    # transcribe -> chat -> speak chain. Wiring only `.upload` leaves
    # microphone recordings without a handler.
    for voice_trigger in (voice_input.upload, voice_input.stop_recording):
        voice_trigger(
            handle_voice_upload,
            inputs=[voice_input],
            outputs=[msg]
        ).then(
            submit_and_generate_audio, **chat_io
        ).then(
            generate_audio, **speak_io
        )

if __name__ == "__main__":
    demo.launch(share=True)