| import os, time, requests, tempfile, asyncio, logging | |
| import gradio as gr | |
| from transformers import pipeline | |
| import edge_tts | |
| from collections import Counter | |
# ─── Configuration ───────────────────────────────────────────────────────────
# Logging is set up before anything else so later code can report through it.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# URL of the hosted text-generation endpoint that receives the mood prompts.
ENDPOINT_URL = "https://xzup8268xrmmxcma.us-east-1.aws.endpoints.huggingface.cloud/invocations"

# Bearer token for the endpoint, read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# ─── Helpers ─────────────────────────────────────────────────────────────────
# 1) Speech→Text
# Shared speech recogniser, created once at import time and reused by every
# call to speech_to_text().
asr = pipeline("automatic-speech-recognition", "facebook/wav2vec2-base-960h")


def speech_to_text(audio):
    """Transcribe a Gradio audio value to text.

    Args:
        audio: A ``(sample_rate, samples)`` tuple (the value a numpy-typed
            ``gr.Audio`` component supplies), a path to an audio file, or a
            falsy value when nothing was recorded.

    Returns:
        The recognised text, or ``""`` when there is no audio.
    """
    if not audio:
        return ""

    if isinstance(audio, tuple):
        sample_rate, samples = audio
        if samples is None or len(samples) == 0:
            return ""
        # The recogniser expects a mono float waveform; Gradio typically hands
        # over integer PCM, possibly with one column per channel.
        samples = samples.astype("float32")
        if samples.ndim > 1:
            samples = samples.mean(axis=1)
        peak = abs(samples).max()
        if peak > 0:
            samples = samples / peak
        # Raw audio must be passed as a dict: the pipeline call does not
        # accept a ``sampling_rate`` keyword, and the dict form lets the
        # pipeline resample to the rate the model was trained on.
        return asr({"sampling_rate": sample_rate, "raw": samples})["text"]

    # Anything else is treated as a file path the pipeline can load itself.
    return asr(audio)["text"]
| # 2) Prompt formatting | |
def format_prompt(message, history):
    """Build the text prompt sent to the mood-analysis model.

    The prompt is the fixed instruction block, followed by every completed
    exchange in ``history``, followed by the new user ``message`` and an open
    ``Assistant:`` line for the model to complete.

    Args:
        message: The user's newest message.
        history: Sequence of ``(user_text, bot_text)`` pairs. Pairs whose
            ``bot_text`` is ``None`` are pending turns and are skipped, so a
            turn that is still waiting for its reply is never rendered as
            "Assistant: None" or duplicated alongside ``message``.

    Returns:
        The complete prompt string.
    """
    fixed_prompt = """
You are a smart mood analyzer tasked with determining the user's mood for a music recommendation system. Your goal is to classify the user's mood into one of four categories: Happy, Sad, Instrumental, or Party.
Instructions:
1. Engage in a conversation with the user to understand their mood.
2. Ask relevant questions to guide the conversation towards mood classification.
3. If the user's mood is clear, respond with a single word: "Happy", "Sad", "Instrumental", or "Party".
4. If the mood is unclear, continue the conversation with a follow-up question.
5. Limit the conversation to a maximum of 5 exchanges.
6. Do not classify the mood prematurely if it's not evident from the user's responses.
7. Focus on the user's emotional state rather than specific activities or preferences.
8. If unable to classify after 5 exchanges, respond with "Unclear" to indicate the need for more information.
Remember: Your primary goal is mood classification. Stay on topic and guide the conversation towards understanding the user's emotional state.
"""
    parts = [f"{fixed_prompt}\n"]
    completed = 0
    for user_text, bot_text in history:
        if bot_text is None:
            continue
        parts.append(f"User: {user_text}\nAssistant: {bot_text}\n")
        completed += 1
        # After the fourth completed exchange the next reply is the last one
        # the model is allowed, so tell it to commit to an answer.
        if completed == 4:
            parts.append(
                "Note: This is the last exchange. Classify the mood if possible or respond with 'Unclear'.\n"
            )
    parts.append(f"User: {message}\nAssistant:")
    return "".join(parts)
| # 3) Call HF Invocation Endpoint | |
def query_model(prompt, max_new_tokens=64, temperature=0.1):
    """Send ``prompt`` to the hosted model and return the generated text.

    Args:
        prompt: Full prompt text to complete.
        max_new_tokens: Upper bound on generated tokens, forwarded to the
            endpoint.
        temperature: Sampling temperature, forwarded to the endpoint.

    Returns:
        The ``generated_text`` of the first result, with a leading copy of
        ``prompt`` removed if the endpoint echoed it back.

    Raises:
        requests.RequestException: On connection problems, timeouts (30 s) or
            a non-2xx response.
        KeyError, IndexError: If the response JSON is not a list whose first
            item has a ``generated_text`` key.
    """
    headers = {"Content-Type": "application/json"}
    if HF_TOKEN:
        # Only authenticate when a token is configured; otherwise the header
        # would carry the literal string "Bearer None".
        headers["Authorization"] = f"Bearer {HF_TOKEN}"
    payload = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature},
    }
    resp = requests.post(ENDPOINT_URL, headers=headers, json=payload, timeout=30)
    resp.raise_for_status()
    generated = resp.json()[0]["generated_text"]
    # Some text-generation backends return prompt + completion. The prompt
    # itself names every mood category, so an echoed prompt must not reach
    # the mood detection that runs on the reply.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated
| # 4) Aggregate mood from history | |
def aggregate_mood_from_history(history):
    """Return the mood word mentioned most often in the bot's replies.

    Each bot reply is split on whitespace; tokens are stripped of surrounding
    punctuation, lower-cased and counted if they are one of the four mood
    words.

    Args:
        history: Sequence of ``(user_text, bot_text)`` pairs. Replies that are
            not strings (for example ``None`` for a pending turn) are ignored.

    Returns:
        ``"happy"``, ``"sad"``, ``"instrumental"`` or ``"party"`` (whichever
        occurs most; ties go to the word seen first), or ``None`` when no mood
        word appears.
    """
    mood_words = {"happy", "sad", "instrumental", "party"}
    counts = Counter()
    for _, bot_response in history:
        if not isinstance(bot_response, str):
            continue
        cleaned = (tok.strip('.,?!;"\'').lower() for tok in bot_response.split())
        counts.update(word for word in cleaned if word in mood_words)
    if not counts:
        return None
    return counts.most_common(1)[0][0]
# 5) Text→Speech
def text_to_speech(text):
    """Synthesise ``text`` with edge-tts and return the audio file's path.

    The function is synchronous but safe to call from any context: from plain
    code, from a worker thread, or from inside a coroutine whose event loop is
    already running.

    Args:
        text: The text to speak.

    Returns:
        Path of a temporary audio file. The file is not deleted automatically.
    """
    from concurrent.futures import ThreadPoolExecutor

    communicate = edge_tts.Communicate(text)
    # Reserve a file name and close the handle before edge-tts writes to it
    # (an open handle blocks the write on some platforms). edge-tts writes
    # MP3-encoded audio by default, hence the suffix.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
        audio_path = tmp.name

    def _synthesise():
        # asyncio.run creates and tears down a private event loop.
        asyncio.run(communicate.save(audio_path))

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so one can be started here.
        _synthesise()
    else:
        # A loop is already running in this thread and cannot be re-entered;
        # do the synthesis on a helper thread and wait for it to finish.
        with ThreadPoolExecutor(max_workers=1) as pool:
            pool.submit(_synthesise).result()
    return audio_path
# ─── Gradio Callbacks ────────────────────────────────────────────────────────
def user_turn(user_input, history):
    """Handle one user message: query the model and update the conversation.

    Args:
        user_input: Text typed or transcribed for this turn.
        history: List of ``(user_text, bot_text)`` pairs so far; it is not
            mutated.

    Returns:
        ``(state, chat, textbox)``: the updated history twice (once for the
        state, once for the chat display) and ``""`` to clear the text box.
        Blank input leaves the history unchanged.
    """
    if not user_input or not user_input.strip():
        # Nothing to send (e.g. the microphone widget was cleared).
        return history, history, ""

    # The prompt is built from the turns that came before this one; the new
    # message is passed separately so it appears in the prompt exactly once.
    previous_turns = list(history)
    try:
        raw = query_model(format_prompt(user_input, previous_turns))
    except requests.RequestException:
        logger.exception("Mood model request failed")
        apology = "Sorry, I couldn't reach the mood model right now. Please try again."
        updated = previous_turns + [(user_input, apology)]
        return updated, updated, ""

    # Mood detection looks at every bot reply so far, including the new one.
    mood = aggregate_mood_from_history(previous_turns + [(user_input, raw)])
    if mood:
        reply = f"Playing {mood.capitalize()} playlist for you!"
    else:
        reply = raw
    updated = previous_turns + [(user_input, reply)]
    return updated, updated, ""
async def bot_audio(history):
    """Produce spoken audio for the most recent bot reply.

    Args:
        history: List of ``(user_text, bot_text)`` pairs.

    Returns:
        Whatever ``text_to_speech`` returns for the last reply, or ``None``
        when the history is empty or the last reply is empty or missing.
    """
    if not history:
        return None
    last_reply = history[-1][1]
    if not last_reply:
        return None
    # text_to_speech is a blocking call; run it on the default executor so
    # the event loop serving other requests is not stalled while it works.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, text_to_speech, last_reply)
def speech_callback(audio):
    """Gradio handler for the audio widget: hand ``audio`` to speech_to_text.

    Returns whatever ``speech_to_text`` returns for the given value.
    """
    transcript = speech_to_text(audio)
    return transcript
# ─── Build the Interface ─────────────────────────────────────────────────────
with gr.Blocks() as demo:
    gr.Markdown("## 🎵 Mood-Based Music Buddy")
    chat = gr.Chatbot()
    txt = gr.Textbox(placeholder="Type your mood...", label="Text")
    send = gr.Button("Send")
    mic = gr.Audio()
    out_audio = gr.Audio(label="Response (Audio)", autoplay=True)
    # Conversation history as a list of (user_text, bot_text) pairs.
    state = gr.State([])

    def init():
        """Seed state and chat with the greeting; no audio is produced."""
        greeting = "Hi! I'm your music buddy—tell me how you’re feeling today."
        return [("", greeting)], [("", greeting)], None

    demo.load(init, outputs=[state, chat, out_audio])

    # Every input route ends with the same two steps: run the chat turn, then
    # speak the reply. The arguments are shared so the routes cannot drift.
    turn_step = dict(fn=user_turn, inputs=[txt, state], outputs=[state, chat, txt])
    speak_step = dict(fn=bot_audio, inputs=[state], outputs=[out_audio])

    txt.submit(**turn_step).then(**speak_step)
    send.click(**turn_step).then(**speak_step)
    # Audio input is transcribed into the text box first, then handled like
    # typed text.
    mic.change(speech_callback, [mic], [txt]).then(**turn_step).then(**speak_step)

if __name__ == "__main__":
    demo.launch(debug=True)
| # import gradio as gr | |
| # import requests | |
| # from transformers import pipeline | |
| # import edge_tts | |
| # import tempfile | |
| # import asyncio | |
| # import os | |
| # import json | |
| # import time | |
| # import logging | |
| # # Set up logging | |
| # logging.basicConfig(level=logging.INFO) | |
| # logger = logging.getLogger(__name__) | |
| # ENDPOINT_URL = "https://xzup8268xrmmxcma.us-east-1.aws.endpoints.huggingface.cloud/invocations" | |
| # hf_token = os.getenv("HF_TOKEN") | |
| # print(f"DEBUG: Starting application at {time.strftime('%Y-%m-%d %H:%M:%S')}") | |
| # print(f"DEBUG: HF_TOKEN available: {bool(hf_token)}") | |
| # print(f"DEBUG: Endpoint URL: {ENDPOINT_URL}") | |
| # try: | |
| # print("DEBUG: Loading ASR pipeline...") | |
| # start_time = time.time() | |
| # asr = pipeline("automatic-speech-recognition", "facebook/wav2vec2-base-960h") | |
| # print(f"DEBUG: ASR pipeline loaded in {time.time() - start_time:.2f} seconds") | |
| # except Exception as e: | |
| # print(f"DEBUG: Error loading ASR pipeline: {e}") | |
| # asr = None | |
# INITIAL_MESSAGE = "Hi! I'm your music buddy—tell me about your mood and the type of tunes you're in the mood for today!"
| # def speech_to_text(speech): | |
| # print(f"DEBUG: speech_to_text called with input: {speech is not None}") | |
| # if speech is None: | |
| # print("DEBUG: No speech input provided") | |
| # return "" | |
| # try: | |
| # start_time = time.time() | |
| # print("DEBUG: Starting speech recognition...") | |
| # result = asr(speech)["text"] | |
| # print(f"DEBUG: Speech recognition completed in {time.time() - start_time:.2f} seconds") | |
| # print(f"DEBUG: Recognized text: '{result}'") | |
| # return result | |
| # except Exception as e: | |
| # print(f"DEBUG: Error in speech_to_text: {e}") | |
| # return "" | |
| # def classify_mood(input_string): | |
| # print(f"DEBUG: classify_mood called with: '{input_string}'") | |
| # input_string = input_string.lower() | |
| # mood_words = {"happy", "sad", "instrumental", "party"} | |
| # for word in mood_words: | |
| # if word in input_string: | |
| # print(f"DEBUG: Mood classified as: {word}") | |
| # return word, True | |
| # print("DEBUG: No mood classified") | |
| # return None, False | |
| # def generate(prompt, history, temperature=0.1, max_new_tokens=2048): | |
| # print(f"DEBUG: generate() called at {time.strftime('%H:%M:%S')}") | |
| # print(f"DEBUG: Prompt length: {len(prompt)}") | |
| # print(f"DEBUG: History length: {len(history)}") | |
| # if not hf_token: | |
| # error_msg = "Error: Hugging Face authentication required. Please set your HF_TOKEN." | |
| # print(f"DEBUG: {error_msg}") | |
| # return error_msg | |
| # try: | |
| # print("DEBUG: Formatting prompt...") | |
| # start_time = time.time() | |
| # formatted_prompt = format_prompt(prompt, history) | |
| # print(f"DEBUG: Prompt formatted in {time.time() - start_time:.2f} seconds") | |
| # print(f"DEBUG: Formatted prompt length: {len(formatted_prompt)}") | |
| # headers = {"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"} | |
| # payload = { | |
| # "inputs": formatted_prompt, | |
| # "parameters": { | |
| # "temperature": temperature, | |
| # "max_new_tokens": max_new_tokens | |
| # } | |
| # } | |
| # print("DEBUG: Making API request...") | |
| # api_start_time = time.time() | |
| # response = requests.post(ENDPOINT_URL, headers=headers, json=payload, timeout=60) | |
| # api_duration = time.time() - api_start_time | |
| # print(f"DEBUG: API request completed in {api_duration:.2f} seconds") | |
| # print(f"DEBUG: Response status code: {response.status_code}") | |
| # if response.status_code == 200: | |
| # print("DEBUG: Parsing API response...") | |
| # result = response.json() | |
| # output = result[0]["generated_text"] | |
| # print(f"DEBUG: Generated output: '{output[:100]}...'") | |
| # mood, is_classified = classify_mood(output) | |
| # if is_classified: | |
| # playlist_message = f"Playing {mood.capitalize()} playlist for you!" | |
| # print(f"DEBUG: Returning playlist message: {playlist_message}") | |
| # return playlist_message | |
| # print(f"DEBUG: Returning generated output") | |
| # return output | |
| # else: | |
| # error_msg = f"Error: {response.status_code} - {response.text}" | |
| # print(f"DEBUG: API error: {error_msg}") | |
| # return error_msg | |
| # except requests.exceptions.Timeout: | |
| # error_msg = "Error: API request timed out after 60 seconds" | |
| # print(f"DEBUG: {error_msg}") | |
| # return error_msg | |
| # except Exception as e: | |
| # error_msg = f"Error generating response: {str(e)}" | |
| # print(f"DEBUG: Exception in generate(): {error_msg}") | |
| # return error_msg | |
| # def format_prompt(message, history): | |
| # print("DEBUG: format_prompt called") | |
| # fixed_prompt = """ | |
| # You are a smart mood analyzer tasked with determining the user's mood for a music recommendation system. Your goal is to classify the user's mood into one of four categories: Happy, Sad, Instrumental, or Party. | |
| # Instructions: | |
| # 1. Engage in a conversation with the user to understand their mood. | |
| # 2. Ask relevant questions to guide the conversation towards mood classification. | |
| # 3. If the user's mood is clear, respond with a single word: "Happy", "Sad", "Instrumental", or "Party". | |
| # 4. If the mood is unclear, continue the conversation with a follow-up question. | |
| # 5. Limit the conversation to a maximum of 5 exchanges. | |
| # 6. Do not classify the mood prematurely if it's not evident from the user's responses. | |
| # 7. Focus on the user's emotional state rather than specific activities or preferences. | |
| # 8. If unable to classify after 5 exchanges, respond with "Unclear" to indicate the need for more information. | |
| # Remember: Your primary goal is mood classification. Stay on topic and guide the conversation towards understanding the user's emotional state. | |
| # """ | |
| # prompt = f"{fixed_prompt}\n" | |
| # for i, (user_prompt, bot_response) in enumerate(history): | |
| # prompt += f"User: {user_prompt}\nAssistant: {bot_response}\n" | |
| # if i == 3: | |
| # prompt += "Note: This is the last exchange. Classify the mood if possible or respond with 'Unclear'.\n" | |
| # prompt += f"User: {message}\nAssistant:" | |
| # print(f"DEBUG: Final prompt length: {len(prompt)}") | |
| # return prompt | |
| # async def text_to_speech(text): | |
| # print(f"DEBUG: text_to_speech called with text length: {len(text)}") | |
| # try: | |
| # start_time = time.time() | |
| # print("DEBUG: Creating TTS communicate object...") | |
| # communicate = edge_tts.Communicate(text) | |
| # print("DEBUG: Creating temporary file...") | |
| # with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: | |
| # tmp_path = tmp_file.name | |
| # print(f"DEBUG: Saving TTS to: {tmp_path}") | |
| # await communicate.save(tmp_path) | |
| # duration = time.time() - start_time | |
| # print(f"DEBUG: TTS completed in {duration:.2f} seconds") | |
| # print(f"DEBUG: TTS file size: {os.path.getsize(tmp_path) if os.path.exists(tmp_path) else 'File not found'}") | |
| # return tmp_path | |
| # except Exception as e: | |
| # print(f"DEBUG: TTS Error: {e}") | |
| # return None | |
| # def process_input(input_text, history): | |
| # print(f"DEBUG: process_input called with text: '{input_text[:50]}...'") | |
| # if not input_text: | |
| # print("DEBUG: No input text provided") | |
| # return history, history, "" | |
| # print("DEBUG: Calling generate function...") | |
| # start_time = time.time() | |
| # response = generate(input_text, history) | |
| # duration = time.time() - start_time | |
| # print(f"DEBUG: generate() completed in {duration:.2f} seconds") | |
| # print(f"DEBUG: Response: '{response[:100]}...'") | |
| # history.append((input_text, response)) | |
| # print(f"DEBUG: Updated history length: {len(history)}") | |
| # return history, history, "" | |
| # async def generate_audio(history): | |
| # print(f"DEBUG: generate_audio called with history length: {len(history)}") | |
| # if history and len(history) > 0: | |
| # last_response = history[-1][1] | |
| # print(f"DEBUG: Generating audio for: '{last_response[:50]}...'") | |
| # start_time = time.time() | |
| # audio_path = await text_to_speech(last_response) | |
| # duration = time.time() - start_time | |
| # print(f"DEBUG: Audio generation completed in {duration:.2f} seconds") | |
| # return audio_path | |
| # print("DEBUG: No history available for audio generation") | |
| # return None | |
| # async def init_chat(): | |
| # print("DEBUG: init_chat called") | |
| # try: | |
| # history = [("", INITIAL_MESSAGE)] | |
| # print("DEBUG: Generating initial audio...") | |
| # start_time = time.time() | |
| # audio_path = await text_to_speech(INITIAL_MESSAGE) | |
| # duration = time.time() - start_time | |
| # print(f"DEBUG: Initial audio generated in {duration:.2f} seconds") | |
| # print("DEBUG: init_chat completed successfully") | |
| # return history, history, audio_path | |
| # except Exception as e: | |
| # print(f"DEBUG: Error in init_chat: {e}") | |
| # return [("", INITIAL_MESSAGE)], [("", INITIAL_MESSAGE)], None | |
| # def handle_voice_upload(audio_file): | |
| # print(f"DEBUG: handle_voice_upload called with file: {audio_file}") | |
| # if audio_file is None: | |
| # print("DEBUG: No audio file provided") | |
| # return "" | |
| # try: | |
| # start_time = time.time() | |
| # result = speech_to_text(audio_file) | |
| # duration = time.time() - start_time | |
| # print(f"DEBUG: Voice upload processing completed in {duration:.2f} seconds") | |
| # return result | |
| # except Exception as e: | |
| # print(f"DEBUG: Error in handle_voice_upload: {e}") | |
| # return "" | |
| # print("DEBUG: Creating Gradio interface...") | |
| # with gr.Blocks() as demo: | |
| # gr.Markdown("# Mood-Based Music Recommender with Continuous Voice Chat") | |
| # chatbot = gr.Chatbot() | |
| # with gr.Row(): | |
| # msg = gr.Textbox( | |
| # placeholder="Type your message here...", | |
| # label="Text Input", | |
| # scale=4 | |
| # ) | |
| # submit = gr.Button("Send", scale=1) | |
| # with gr.Row(): | |
| # voice_input = gr.Audio( | |
# label="🎤 Record your voice or upload audio file",
| # sources=["microphone", "upload"], | |
| # type="filepath" | |
| # ) | |
| # audio_output = gr.Audio(label="AI Response", autoplay=True) | |
| # state = gr.State([]) | |
| # print("DEBUG: Setting up Gradio event handlers...") | |
| # demo.load(init_chat, outputs=[state, chatbot, audio_output]) | |
| # def submit_and_generate_audio(input_text, history): | |
| # print(f"DEBUG: submit_and_generate_audio called at {time.strftime('%H:%M:%S')}") | |
| # start_time = time.time() | |
| # new_state, new_chatbot, empty_msg = process_input(input_text, history) | |
| # duration = time.time() - start_time | |
| # print(f"DEBUG: submit_and_generate_audio completed in {duration:.2f} seconds") | |
| # return new_state, new_chatbot, empty_msg | |
| # msg.submit( | |
| # submit_and_generate_audio, | |
| # inputs=[msg, state], | |
| # outputs=[state, chatbot, msg] | |
| # ).then( | |
| # generate_audio, | |
| # inputs=[state], | |
| # outputs=[audio_output] | |
| # ) | |
| # submit.click( | |
| # submit_and_generate_audio, | |
| # inputs=[msg, state], | |
| # outputs=[state, chatbot, msg] | |
| # ).then( | |
| # generate_audio, | |
| # inputs=[state], | |
| # outputs=[audio_output] | |
| # ) | |
| # voice_input.upload( | |
| # handle_voice_upload, | |
| # inputs=[voice_input], | |
| # outputs=[msg] | |
| # ).then( | |
| # submit_and_generate_audio, | |
| # inputs=[msg, state], | |
| # outputs=[state, chatbot, msg] | |
| # ).then( | |
| # generate_audio, | |
| # inputs=[state], | |
| # outputs=[audio_output] | |
| # ) | |
| # print("DEBUG: Gradio interface created successfully") | |
| # if __name__ == "__main__": | |
| # print("DEBUG: Launching Gradio app...") | |
| # demo.launch(share=True, debug=True) |