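# AI Voice Assistant: a Gradio app that records a voice message, transcribes it
# with the hosted openai/whisper-large-v3-turbo API, generates a reply with
# mistralai/Mistral-Nemo-Instruct-2407 via InferenceClient, and speaks the reply
# back with edge_tts. Requires the HF_TOKEN environment variable.
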
import gradio as gr
import asyncio
import edge_tts
import os
from huggingface_hub import InferenceClient
import requests
import tempfile
import logging
import io
from pydub import AudioSegment

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Get the Hugging Face token from the environment
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN environment variable is not set")

# Initialize the Hugging Face Inference Client for chat completion
chat_client = InferenceClient("mistralai/Mistral-Nemo-Instruct-2407", token=hf_token)

# Whisper API settings
WHISPER_API_URL = "https://api-inference.huggingface.co/models/openai/whisper-large-v3-turbo"
headers = {"Authorization": f"Bearer {hf_token}"}

# Initialize an empty chat history
chat_history = []
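
# Note: chat_history is module-level state, so in a deployed Space it is shared
# across every user and session rather than kept per visitor.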

async def text_to_speech_stream(text, voice_volume=1.0):
    """Convert text to speech using edge_tts and return the audio file path."""
    communicate = edge_tts.Communicate(text, "en-US-BrianMultilingualNeural")
    audio_data = b""
    async for chunk in communicate.stream():
        if chunk["type"] == "audio":
            audio_data += chunk["data"]

    # Map the 0.0-2.0 volume slider onto a -20 dB to +20 dB gain (1.0 plays back unchanged)
    audio = AudioSegment.from_mp3(io.BytesIO(audio_data))
    adjusted_audio = audio + (20 * voice_volume - 20)

    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
        adjusted_audio.export(temp_file.name, format="mp3")
    return temp_file.name
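
# A worked example of the slider-to-gain mapping above (pydub's "+" applies decibels):
#   voice_volume 0.0 -> -20 dB (much quieter), 0.5 -> -10 dB, 1.0 -> 0 dB, 2.0 -> +20 dB
# e.g. asyncio.run(text_to_speech_stream("Hello!", voice_volume=0.5)) writes an MP3 at -10 dB.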

def whisper_speech_to_text(audio_path):
    """Convert speech to text using the Hugging Face Whisper API."""
    if audio_path is None:
        logging.error("Error: No audio file provided")
        return ""
    if not os.path.exists(audio_path):
        logging.error(f"Error: Audio file not found at {audio_path}")
        return ""
    try:
        with open(audio_path, "rb") as audio_file:
            data = audio_file.read()
        response = requests.post(WHISPER_API_URL, headers=headers, data=data)
        response.raise_for_status()  # Raise an exception for bad status codes
        result = response.json()
        transcribed_text = result.get("text", "")
        logging.info(f"Transcribed text: {transcribed_text}")
        return transcribed_text
    except requests.exceptions.RequestException as e:
        logging.error(f"Error during API request: {e}")
        return ""
    except Exception as e:
        logging.error(f"Unexpected error in whisper_speech_to_text: {e}")
        return ""

async def chat_with_ai(message, voice_volume=1.0):
    global chat_history
    chat_history.append({"role": "user", "content": message})
    try:
        response = chat_client.chat_completion(
            messages=[{"role": "system", "content": "You are a helpful voice assistant. Provide concise and clear responses to user queries."}] + chat_history,
            max_tokens=800,
            temperature=0.7
        )
        response_text = response.choices[0].message.content
        chat_history.append({"role": "assistant", "content": response_text})
        # Synthesize the reply once, at the requested volume
        audio_path = await text_to_speech_stream(response_text, voice_volume)
        return response_text, audio_path
    except Exception as e:
        logging.error(f"Error in chat_with_ai: {e}")
        return str(e), None

def transcribe_and_chat(audio, voice_volume=1.0):
    if audio is None:
        return "Sorry, no audio was provided. Please try recording again.", None
    text = whisper_speech_to_text(audio)
    if not text:
        return "Sorry, I couldn't understand the audio or there was an error in transcription. Please try again.", None
    response, audio_path = asyncio.run(chat_with_ai(text, voice_volume))
    return response, audio_path
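
# Quick smoke test without launching the UI (hypothetical path, for illustration):
#   reply, mp3_path = transcribe_and_chat("recording.wav", voice_volume=1.0)
#   print(reply, mp3_path)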

def create_demo():
    with gr.Blocks(css="""
        @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;700&display=swap');
        body { font-family: 'Poppins', sans-serif; margin: 0; padding: 0; box-sizing: border-box; }
        #audio-input { border: 2px solid #ffb703; padding: 10px; }
        #chat-output { background-color: #023047; color: #ffffff; font-size: 1.2em; }
        #audio-output { border: 2px solid #8ecae6; }
        #clear-button { background-color: #fb8500; color: white; }
        #voice-volume { background-color: #219ebc; }
        button { font-size: 16px; }
        audio { background-color: #ffb703; border-radius: 10px; }
        footer { display: none; }
        @media (max-width: 768px) {
            #audio-input, #chat-output, #audio-output { width: 100%; }
            button { width: 100%; }
        }
    """) as demo:
        gr.Markdown(
            """
            <div style='text-align:center; color:#023047; font-size: 28px; font-weight: bold;'>🗣️ AI Voice Assistant</div>
            <p style='text-align:center; color:#8ecae6; font-size: 18px;'>Talk to your personal AI! Record your voice, and get a response in both text and audio.</p>
            <p style='text-align:center; color:#8ecae6;'>Powered by advanced AI models for real-time interaction.</p>
            """,
            elem_id="header"
        )

        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(type="filepath", label="🎤 Record your voice", elem_id="audio-input")
                clear_button = gr.Button("Clear", variant="secondary", elem_id="clear-button")
                voice_volume = gr.Slider(minimum=0, maximum=2, value=1, step=0.1, label="Voice Volume", elem_id="voice-volume")
            with gr.Column(scale=1):
                chat_output = gr.Textbox(label="💬 AI Response", elem_id="chat-output", lines=5, interactive=False)
                audio_output = gr.Audio(label="🔊 AI Voice Response", autoplay=True, elem_id="audio-output")

        # Add some spacing and a divider
        gr.Markdown("<hr style='border: 1px solid #8ecae6;'/>")

        # Process the recorded audio: transcribe it, query the chat model, synthesize the reply
        def process_audio(audio, volume):
            logging.info(f"Received audio: {audio}")
            if audio is None:
                return "No audio detected. Please try recording again.", None
            # The volume is passed through to the single TTS call inside chat_with_ai,
            # so the reply is no longer synthesized twice
            response, audio_path = transcribe_and_chat(audio, volume)
            logging.info(f"Response: {response}, Audio path: {audio_path}")
            return response, audio_path

        audio_input.change(process_audio, inputs=[audio_input, voice_volume], outputs=[chat_output, audio_output])
        clear_button.click(lambda: (None, None), None, [chat_output, audio_output])

        # JavaScript to handle autoplay, automatic submission, and auto-listen
        demo.load(None, js="""
        function() {
            var recordButton;

            function findRecordButton() {
                var buttons = document.querySelectorAll('button');
                for (var i = 0; i < buttons.length; i++) {
                    if (buttons[i].textContent.includes('Record from microphone')) {
                        return buttons[i];
                    }
                }
                return null;
            }

            function startListening() {
                if (!recordButton) {
                    recordButton = findRecordButton();
                }
                if (recordButton) {
                    recordButton.click();
                }
            }

            // Guard against the player not being rendered yet at load time
            var firstAudio = document.querySelector("audio");
            if (firstAudio) {
                firstAudio.addEventListener("ended", function() {
                    setTimeout(startListening, 500);
                });
            }

            function playAssistantAudio() {
                var audioElements = document.querySelectorAll('audio');
                if (audioElements.length > 1) {
                    var assistantAudio = audioElements[1];
                    if (assistantAudio) {
                        assistantAudio.play();
                    }
                }
            }

            document.addEventListener('gradioAudioLoaded', function(event) {
                playAssistantAudio();
            });
            document.addEventListener('gradioUpdated', function(event) {
                setTimeout(playAssistantAudio, 100);
            });

            // Resume playback so audio keeps going when the user switches tabs
            document.addEventListener("visibilitychange", function() {
                var audioElements = document.querySelectorAll('audio');
                audioElements.forEach(function(audio) {
                    audio.play();
                });
            });
        }
        """)

    return demo

# Launch the Gradio app
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860)
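
# To run locally: set HF_TOKEN, install the dependencies (gradio, edge-tts,
# huggingface_hub, requests, pydub, plus ffmpeg for pydub's MP3 handling),
# then run this file and open http://localhost:7860.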