# Hugging Face Spaces page header (scrape artifact): the Space reported "Build error".
import asyncio
import logging
import os
import tempfile
import time
from typing import Dict, Any, Optional, List, Callable

import gradio as gr
from dotenv import load_dotenv
from livekit import rtc, agents
from openai import AsyncOpenAI
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Load environment variables | |
| load_dotenv() | |
| # Prompt for the speech therapist agent | |
| SPEECH_THERAPIST_PROMPT = """ | |
| You are a speech pathologist, a healthcare professional who specializes in evaluating, diagnosing, and treating communication disorders, including speech, language, cognitive-communication, voice, and fluency disorders. Your role is to help patients improve their speech and communication skills through various therapeutic techniques and exercises. | |
| Your are working with a student with speech impediments typically with ASD. | |
| You have to be rigid to help them stay on the right track. You have to start with some sort of intro activity and can not rely on the student at all to complete your thoughts. You pick a place to start and assess the speech from there. | |
| IMPORTANT GUIDELINES: | |
| 1. Keep responses concise and clear - students may have difficulty with long explanations | |
| 2. Use simple language when possible but don't talk down to the student | |
| 3. When a student seems confused, offer to repeat what you said | |
| 4. Start with a brief introduction and a simple activity | |
| 5. Maintain a structured, predictable interaction pattern | |
| 6. Be patient and positive, but maintain expectations | |
| 7. If the student asks you to repeat something, say "I'll repeat that" and then repeat your previous message clearly | |
| 8. Always wait for the student to finish speaking before responding | |
| 9. Keep your audio responses under 30 seconds to reduce cognitive load | |
| 10. Be responsive to signs of frustration or confusion | |
| Start with a brief introduction and a simple articulation exercise to assess the student's baseline skills. | |
| """ | |
class SpeechTherapistAgent:
    """
    AI speech therapist that pairs OpenAI Whisper transcription with GPT chat
    completions for a low-latency voice conversation with neurodivergent students.

    State kept between turns:
        last_response        -- most recent assistant reply (for "repeat" requests)
        conversation_history -- alternating user/assistant messages, trimmed to 20
        is_processing        -- re-entrancy guard for process_audio
        last_processed_time  -- timestamp of the last finished turn (debounce)
    """

    def __init__(self):
        """Initialize the OpenAI client and per-session conversation state."""
        self._setup_openai_client()
        self.last_response = ""
        self.conversation_history = []
        self.is_processing = False
        self.last_processed_time = 0
        # Minimum seconds between turns; avoids overlapping Whisper/chat calls.
        self.MIN_PROCESSING_INTERVAL = 1.0

    def _setup_openai_client(self):
        """Create the AsyncOpenAI client from the OPENAI_API_KEY env var.

        Raises:
            ValueError: if OPENAI_API_KEY is not set.
        """
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set")
        self.client = AsyncOpenAI(api_key=api_key)

    async def process_audio(self, audio_data: bytes) -> Optional[str]:
        """
        Process one chunk of user audio and return the AI's spoken reply.

        Returns None when the call is debounced, another turn is in flight,
        or the transcript is empty/too short. Unexpected errors are logged
        and answered with a gentle fallback sentence instead of raising.
        """
        current_time = time.time()
        # Debounce: skip while a turn is in flight or right after one finished.
        if self.is_processing or (current_time - self.last_processed_time < self.MIN_PROCESSING_INTERVAL):
            return None
        self.is_processing = True
        try:
            transcript = await self._transcribe_audio(audio_data)
            if not transcript:
                return None
            logger.info("Transcribed: %s", transcript)
            # Skip one-character noise transcripts.
            if len(transcript.strip()) < 2:
                return None
            response = await self._get_ai_response(transcript)
            self.last_response = response
            self.conversation_history.append({"role": "user", "content": transcript})
            self.conversation_history.append({"role": "assistant", "content": response})
            # Bound the prompt size: keep only the 20 most recent messages.
            if len(self.conversation_history) > 20:
                self.conversation_history = self.conversation_history[-20:]
            return response
        except Exception as e:
            logger.error(f"Error processing audio: {e}")
            return "I'm sorry, I had trouble understanding. Could you please repeat that?"
        finally:
            self.is_processing = False
            self.last_processed_time = time.time()

    async def _transcribe_audio(self, audio_data: bytes) -> str:
        """Transcribe raw audio bytes with OpenAI Whisper; return "" on failure.

        Fix over the original: writes to a unique temporary file (the fixed
        "temp_audio.wav" name raced between concurrent sessions) and removes
        it in a finally block (previously the file leaked whenever the
        transcription call raised before os.remove).
        """
        temp_filename = None
        try:
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp.write(audio_data)
                temp_filename = tmp.name
            with open(temp_filename, "rb") as audio_file:
                transcript = await self.client.audio.transcriptions.create(
                    model="whisper-1",
                    file=audio_file
                )
            return transcript.text
        except Exception as e:
            logger.error(f"Transcription error: {e}")
            return ""
        finally:
            # Always clean up, even when transcription failed.
            if temp_filename and os.path.exists(temp_filename):
                os.remove(temp_filename)

    async def _get_ai_response(self, user_input: str) -> str:
        """Return the assistant reply for *user_input* via chat completion.

        "Repeat" requests are answered locally -- no API call -- by echoing
        the stored last response through repeat_last_response().
        """
        # Handle repeat requests first; the original built the full message
        # list before this check, wasting work on the short-circuit path.
        lowered = user_input.lower()
        if "repeat" in lowered or "say again" in lowered or "what did you say" in lowered:
            return self.repeat_last_response()

        messages = [{"role": "system", "content": SPEECH_THERAPIST_PROMPT}]
        # Only the last 6 messages of context, to keep latency low.
        messages.extend(self.conversation_history[-6:])
        messages.append({"role": "user", "content": user_input})

        response = await self.client.chat.completions.create(
            model="gpt-4o",  # best balance of quality and latency
            messages=messages,
            max_tokens=150,  # concise replies -> lower latency and cognitive load
            temperature=0.7
        )
        return response.choices[0].message.content

    def repeat_last_response(self) -> str:
        """Return the last reply prefixed with "I'll repeat that.", or a
        session-starter sentence when nothing has been said yet."""
        if self.last_response:
            return f"I'll repeat that. {self.last_response}"
        return "I don't have anything to repeat yet. Let's start our session."
class LiveKitVoiceAgent:
    """
    Real-time voice agent: joins a LiveKit room, feeds incoming audio to the
    SpeechTherapistAgent, and speaks replies back through OpenAI TTS.
    """

    def __init__(self, room_url: str, token: str):
        """Store connection details and build the therapist; no I/O happens here."""
        self.room_url = room_url
        self.token = token
        self.room = None
        self.audio_source = None
        self.speech_therapist = SpeechTherapistAgent()
        self.tts_agent = None

    async def connect(self):
        """Join the room, wire up audio in/out, greet, then stream continuously."""
        try:
            room = rtc.Room()
            self.room = room
            await room.connect(self.room_url, self.token)
            logger.info(f"Connected to room: {self.room.name}")
            # Microphone capture from the room.
            self.audio_source = agents.AudioSource(room)
            # Speech synthesis back into the room.
            self.tts_agent = agents.OpenAITTSAgent(
                api_key=os.getenv("OPENAI_API_KEY"),
                room=room,
                voice="alloy",  # Using a clear voice that works well for neurodivergent students
                model="tts-1",  # Using the fastest model for lower latency
            )
            # Greet first, then enter the processing loop.
            await self._speak_initial_message()
            await self._process_audio_stream()
        except Exception as e:
            logger.error(f"Error connecting to LiveKit: {e}")
            raise

    async def _speak_initial_message(self):
        """Deliver the opening greeting and remember it for repeat requests."""
        initial_message = "Hello! I'm your speech therapist today. We'll be doing some simple speech exercises together. I'm here to help you practice. Let's start with an introduction. Can you tell me your name?"
        await self.tts_agent.speak(initial_message)
        self.speech_therapist.last_response = initial_message

    async def _process_audio_stream(self):
        """Accumulate incoming frames and process them in ~2-second batches."""
        pending = bytearray()
        async for frame in self.audio_source:
            pending.extend(frame.data)
            # Wait until roughly 2 seconds of 16 kHz audio has accumulated.
            if len(pending) < 32000:
                continue
            chunk = bytes(pending)
            pending.clear()
            reply = await self.speech_therapist.process_audio(chunk)
            if reply:
                await self.tts_agent.speak(reply)

    async def disconnect(self):
        """Leave the LiveKit room if a connection was established."""
        if self.room:
            await self.room.disconnect()
            logger.info("Disconnected from room")
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for the speech-therapy session.

    The UI collects a LiveKit room URL and token and exposes start / repeat /
    end controls. A single active session (agent + its background connect()
    task) is held in closure state.
    """
    # CSS for a centered, readable layout.
    css = """
    .container {
        max-width: 800px;
        margin: auto;
        padding: 20px;
    }
    .title {
        text-align: center;
        margin-bottom: 20px;
    }
    .instructions {
        background-color: #f9f9f9;
        padding: 15px;
        border-radius: 10px;
        margin-bottom: 20px;
    }
    """

    # Closure state: the active agent and its background connect() task.
    agent = None
    task = None

    def _log_task_result(t):
        """Surface failures from the background connect task.

        Fix: the task was never awaited, so any exception raised inside
        connect() was silently swallowed. Log it instead.
        """
        if t.cancelled():
            return
        exc = t.exception()
        if exc:
            logger.error(f"LiveKit session task failed: {exc}")

    async def start_session(room_url, token):
        """Start the LiveKit session, tearing down any existing one first."""
        nonlocal agent, task
        if not room_url or not token:
            return "Please provide both LiveKit room URL and token."
        try:
            # Fix: the original overwrote `agent`/`task` on a second click,
            # leaking the previous room connection and its running task.
            if agent:
                await agent.disconnect()
            if task:
                task.cancel()
            agent = LiveKitVoiceAgent(room_url, token)
            task = asyncio.create_task(agent.connect())
            task.add_done_callback(_log_task_result)
            return "Session started! The AI speech therapist is now listening."
        except Exception as e:
            logger.error(f"Error starting session: {e}")
            return f"Error: {str(e)}"

    async def end_session():
        """End the LiveKit session and clear the closure state."""
        nonlocal agent, task
        if agent:
            await agent.disconnect()
            if task:
                task.cancel()
            agent = None
            task = None
            return "Session ended."
        return "No active session to end."

    async def repeat_last():
        """Speak the therapist's last response again."""
        nonlocal agent
        if agent and agent.speech_therapist:
            response = agent.speech_therapist.repeat_last_response()
            if agent.tts_agent:
                await agent.tts_agent.speak(response)
                return "Repeating last response."
        return "No active session or no response to repeat."

    # Assemble the Blocks layout.
    with gr.Blocks(css=css) as interface:
        gr.Markdown("# AI Speech Therapist", elem_classes=["title"])
        with gr.Row(elem_classes=["instructions"]):
            gr.Markdown("""
            ## Instructions
            1. Enter your LiveKit room URL and token
            2. Click "Start Session" to begin
            3. Speak clearly into your microphone
            4. Use the "Repeat" button if you need the AI to repeat its last response
            5. Click "End Session" when finished
            This tool is designed for students with speech impediments, particularly those with ASD.
            """)
        with gr.Row():
            with gr.Column():
                room_url = gr.Textbox(label="LiveKit Room URL", placeholder="wss://your-livekit-instance.com")
                token = gr.Textbox(label="LiveKit Token", placeholder="Your LiveKit token")
                with gr.Row():
                    start_btn = gr.Button("Start Session", variant="primary")
                    repeat_btn = gr.Button("Repeat Last Response", variant="secondary")
                    end_btn = gr.Button("End Session", variant="stop")
                status = gr.Textbox(label="Status", value="Enter LiveKit details and click Start Session")
        # Wire the buttons to their (async) handlers.
        start_btn.click(start_session, inputs=[room_url, token], outputs=status)
        repeat_btn.click(repeat_last, outputs=status)
        end_btn.click(end_session, outputs=status)
    return interface
def main():
    """Build the Gradio UI and launch the web server."""
    interface = create_gradio_interface()
    # Fix: `enable_queue` was removed from launch() in Gradio 4.x and raises
    # TypeError at startup (the likely cause of the Space's build error);
    # request queuing for multiple users is configured via .queue() instead.
    interface.queue()
    # Launch the Gradio app
    interface.launch(
        share=True,  # For testing - creates a public link
        server_name="0.0.0.0",  # Binds to all interfaces
        server_port=7860,  # Default Gradio port
    )


if __name__ == "__main__":
    main()