import gradio as gr
import requests
from transformers import pipeline
import edge_tts
import tempfile
import os
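
# Dedicated Hugging Face Inference Endpoint (it exposes an OpenAI-compatible
# /v1/chat/completions route, used in generate() below) plus the auth token it
# requires; a local wav2vec2 pipeline handles speech-to-text for uploads.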
ENDPOINT_URL = "https://l8opkfvazwgxqljm.us-east-1.aws.endpoints.huggingface.cloud"
hf_token = os.getenv("HF_TOKEN")
asr = pipeline("automatic-speech-recognition", "facebook/wav2vec2-base-960h")

INITIAL_MESSAGE = "Hi! I'm your music buddy—tell me how you're feeling and what kind of tunes you're in the mood for today!"
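
# Transcribe an audio file with the local ASR pipeline; returns an empty
# string when no audio was captured.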
def speech_to_text(speech):
    if speech is None:
        return ""
    return asr(speech)["text"]
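
# Look for one of the four playlist moods in the model's reply. This is a
# plain substring check, so a single-word classification such as "Happy"
# (what the system prompt asks for) triggers playlist playback.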
def classify_mood(input_string):
    input_string = input_string.lower()
    mood_words = {"happy", "sad", "instrumental", "party"}
    for word in mood_words:
        if word in input_string:
            return word, True
    return None, False
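
# Send the flattened conversation to the hosted Llama model. If the reply
# contains a mood keyword, return a playlist message instead of the raw text.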
def generate(prompt, history, temperature=0.1, max_new_tokens=2048):
    if not hf_token:
        return "Error: Hugging Face authentication required. Please set your HF_TOKEN."
    formatted_prompt = format_prompt(prompt, history)
    headers = {
        "Authorization": f"Bearer {hf_token}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [{"role": "user", "content": formatted_prompt}],
        "temperature": temperature,
        "max_tokens": max_new_tokens,
        "stream": False
    }
    try:
        # A timeout keeps the UI responsive if the endpoint is cold or unreachable.
        response = requests.post(
            f"{ENDPOINT_URL}/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60,
        )
        if response.status_code == 200:
            result = response.json()
            output = result["choices"][0]["message"]["content"]
            mood, is_classified = classify_mood(output)
            if is_classified:
                return f"Playing {mood.capitalize()} playlist for you!"
            return output
        else:
            return f"Error: {response.status_code} - {response.text}"
    except Exception as e:
        return f"Error generating response: {e}"
def format_prompt(message, history):
    fixed_prompt = """
You are a smart mood analyzer tasked with determining the user's mood for a music recommendation system. Your goal is to classify the user's mood into one of four categories: Happy, Sad, Instrumental, or Party.

Instructions:
1. Engage in a conversation with the user to understand their mood.
2. Ask relevant questions to guide the conversation towards mood classification.
3. If the user's mood is clear, respond with a single word: "Happy", "Sad", "Instrumental", or "Party".
4. If the mood is unclear, continue the conversation with a follow-up question.
5. Limit the conversation to a maximum of 5 exchanges.
6. Do not classify the mood prematurely if it's not evident from the user's responses.
7. Focus on the user's emotional state rather than specific activities or preferences.
8. If unable to classify after 5 exchanges, respond with "Unclear" to indicate the need for more information.

Remember: Your primary goal is mood classification. Stay on topic and guide the conversation towards understanding the user's emotional state.
"""
    prompt = f"{fixed_prompt}\n"
    for i, (user_prompt, bot_response) in enumerate(history):
        prompt += f"User: {user_prompt}\nAssistant: {bot_response}\n"
        # history[0] is the seeded greeting, so i == 3 is the fourth logged
        # turn: warn the model that the next reply must classify or bail out.
        if i == 3:
            prompt += "Note: This is the last exchange. Classify the mood if possible or respond with 'Unclear'.\n"
    prompt += f"User: {message}\nAssistant:"
    return prompt
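
# Synthesize speech with edge-tts. No voice is passed to Communicate(), so the
# library's default English voice is used; edge-tts writes MP3 audio by
# default, which is why the temp file below uses an .mp3 suffix.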
async def text_to_speech(text):
    try:
        communicate = edge_tts.Communicate(text)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            tmp_path = tmp_file.name
        await communicate.save(tmp_path)
        return tmp_path
    except Exception as e:
        print(f"TTS Error: {e}")
        return None
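
# Chat handler: generate a reply for the typed message, append the exchange to
# history, and clear the textbox (outputs map to [state, chatbot, msg] below).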
def process_input(input_text, history):
    if not input_text:
        return history, history, ""
    response = generate(input_text, history)
    history.append((input_text, response))
    return history, history, ""
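
# Voice the most recent assistant reply; wired as a .then() step so the audio
# player updates after the text response lands in the chat.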
async def generate_audio(history):
    if history:
        last_response = history[-1][1]
        return await text_to_speech(last_response)
    return None
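
# Seed the chat with the greeting on page load and speak it aloud.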
async def init_chat():
    history = [("", INITIAL_MESSAGE)]
    audio_path = await text_to_speech(INITIAL_MESSAGE)
    return history, history, audio_path
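
# Transcribe an uploaded recording into the message box; the upload event's
# .then() chain then submits it like a typed message.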
def handle_voice_upload(audio_file):
    if audio_file is None:
        return ""
    return speech_to_text(audio_file)
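
# Gradio UI: chatbot pane, text input with Send button, voice-recording
# upload, and an autoplaying audio player for the spoken reply.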
with gr.Blocks() as demo:
    gr.Markdown("# Mood-Based Music Recommender with Voice Chat")
    chatbot = gr.Chatbot()
    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type your message here...",
            label="Text Input",
            scale=4,
        )
        submit = gr.Button("Send", scale=1)
    with gr.Row():
        voice_input = gr.File(
            label="Upload Voice Recording (or record using your device)",
            file_types=[".wav", ".mp3", ".m4a", ".ogg"],
        )
    audio_output = gr.Audio(label="AI Response", autoplay=True)
    state = gr.State([])

    demo.load(init_chat, outputs=[state, chatbot, audio_output])
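
    # Shared handler for typed messages, the Send button, and the voice-upload
    # chain. It only updates the chat text; the spoken reply is produced by
    # generate_audio in the .then() step of each event.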
    def submit_and_generate_audio(input_text, history):
        new_state, new_chatbot, empty_msg = process_input(input_text, history)
        return new_state, new_chatbot, empty_msg

    msg.submit(
        submit_and_generate_audio,
        inputs=[msg, state],
        outputs=[state, chatbot, msg],
    ).then(
        generate_audio,
        inputs=[state],
        outputs=[audio_output],
    )
    submit.click(
        submit_and_generate_audio,
        inputs=[msg, state],
        outputs=[state, chatbot, msg],
    ).then(
        generate_audio,
        inputs=[state],
        outputs=[audio_output],
    )
    voice_input.upload(
        handle_voice_upload,
        inputs=[voice_input],
        outputs=[msg],
    ).then(
        submit_and_generate_audio,
        inputs=[msg, state],
        outputs=[state, chatbot, msg],
    ).then(
        generate_audio,
        inputs=[state],
        outputs=[audio_output],
    )

if __name__ == "__main__":
    demo.launch(share=True)