# CASLAppGradio / app.py
# Author: SreekarB — initial commit "Create app.py" (b6c64c2, verified)
import asyncio
import logging
import os
import tempfile
import time
from typing import Any, Callable, Dict, List, Optional

import gradio as gr
from dotenv import load_dotenv
from livekit import rtc, agents
from openai import AsyncOpenAI
# Configure logging
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()  # reads a local .env file so OPENAI_API_KEY can live there
# Prompt for the speech therapist agent.
# Sent as the first "system" message on every chat-completion call; fixes the
# assistant's persona (speech pathologist for students with ASD) and its
# interaction rules (short replies, structured turns, explicit repeats).
SPEECH_THERAPIST_PROMPT: str = """
You are a speech pathologist, a healthcare professional who specializes in evaluating, diagnosing, and treating communication disorders, including speech, language, cognitive-communication, voice, and fluency disorders. Your role is to help patients improve their speech and communication skills through various therapeutic techniques and exercises.
Your are working with a student with speech impediments typically with ASD.
You have to be rigid to help them stay on the right track. You have to start with some sort of intro activity and can not rely on the student at all to complete your thoughts. You pick a place to start and assess the speech from there.
IMPORTANT GUIDELINES:
1. Keep responses concise and clear - students may have difficulty with long explanations
2. Use simple language when possible but don't talk down to the student
3. When a student seems confused, offer to repeat what you said
4. Start with a brief introduction and a simple activity
5. Maintain a structured, predictable interaction pattern
6. Be patient and positive, but maintain expectations
7. If the student asks you to repeat something, say "I'll repeat that" and then repeat your previous message clearly
8. Always wait for the student to finish speaking before responding
9. Keep your audio responses under 30 seconds to reduce cognitive load
10. Be responsive to signs of frustration or confusion
Start with a brief introduction and a simple articulation exercise to assess the student's baseline skills.
"""
class SpeechTherapistAgent:
    """
    AI speech-therapist agent: transcribes student audio with OpenAI Whisper,
    generates short therapist replies via chat completions, and keeps a rolling
    conversation history. Designed for low-latency voice-to-voice interaction
    with neurodivergent students.
    """
    def __init__(self):
        """Initialize the OpenAI client and per-session state."""
        self._setup_openai_client()
        self.last_response = ""          # last assistant reply, for "repeat" requests
        self.conversation_history = []   # list of {"role", "content"} chat messages
        self.is_processing = False       # guards against overlapping audio jobs
        self.last_processed_time = 0.0
        self.MIN_PROCESSING_INTERVAL = 1.0  # Minimum seconds between processing to reduce latency

    def _setup_openai_client(self):
        """Create the AsyncOpenAI client from the OPENAI_API_KEY env var.

        Raises:
            ValueError: if OPENAI_API_KEY is not set.
        """
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set")
        self.client = AsyncOpenAI(api_key=api_key)

    async def process_audio(self, audio_data: bytes) -> Optional[str]:
        """
        Process one chunk of user audio and return the AI's reply.

        Debounced: returns None (without processing) if a previous chunk is
        still in flight or the last run finished less than
        MIN_PROCESSING_INTERVAL seconds ago.

        Args:
            audio_data: raw audio bytes captured from the user.

        Returns:
            The assistant's reply text, None when the chunk was skipped, or a
            fixed apology string when processing failed.
        """
        current_time = time.time()
        # Debounce to avoid processing too frequently
        if self.is_processing or (current_time - self.last_processed_time < self.MIN_PROCESSING_INTERVAL):
            return None
        self.is_processing = True
        try:
            transcript = await self._transcribe_audio(audio_data)
            if not transcript:
                return None
            logger.info(f"Transcribed: {transcript}")
            # Skip noise: transcripts shorter than 2 chars are not real speech
            if len(transcript.strip()) < 2:
                return None
            response = await self._get_ai_response(transcript)
            self.last_response = response
            # Record the exchange and keep only the most recent 20 messages
            self.conversation_history.append({"role": "user", "content": transcript})
            self.conversation_history.append({"role": "assistant", "content": response})
            if len(self.conversation_history) > 20:
                self.conversation_history = self.conversation_history[-20:]
            return response
        except Exception as e:
            logger.error(f"Error processing audio: {e}")
            return "I'm sorry, I had trouble understanding. Could you please repeat that?"
        finally:
            self.is_processing = False
            self.last_processed_time = time.time()

    async def _transcribe_audio(self, audio_data: bytes) -> str:
        """Transcribe raw audio bytes with OpenAI Whisper; return "" on failure.

        Uses a unique temporary file (the original used a fixed
        "temp_audio.wav", so concurrent sessions clobbered each other's audio)
        and removes it in ``finally`` (the original leaked the file whenever
        transcription raised).
        """
        temp_path = None
        try:
            # The transcription API wants a named file object; write the bytes
            # to a unique temp file. delete=False so it can be reopened below.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp.write(audio_data)
                temp_path = tmp.name
            with open(temp_path, "rb") as audio_file:
                transcript = await self.client.audio.transcriptions.create(
                    model="whisper-1",
                    file=audio_file
                )
            return transcript.text
        except Exception as e:
            logger.error(f"Transcription error: {e}")
            return ""
        finally:
            # Always clean up the temp file, even on error
            if temp_path and os.path.exists(temp_path):
                os.remove(temp_path)

    async def _get_ai_response(self, user_input: str) -> str:
        """Return the therapist's reply to *user_input* via chat completion.

        Repeat requests are short-circuited locally (no API round trip) by
        delegating to repeat_last_response(), which the original duplicated.
        """
        # Check for a repeat request BEFORE building messages / calling the API
        lowered = user_input.lower()
        if "repeat" in lowered or "say again" in lowered or "what did you say" in lowered:
            return self.repeat_last_response()
        messages = [{"role": "system", "content": SPEECH_THERAPIST_PROMPT}]
        # Only the last 6 messages are sent, to bound token cost and latency
        messages.extend(self.conversation_history[-6:])
        messages.append({"role": "user", "content": user_input})
        response = await self.client.chat.completions.create(
            model="gpt-4o",  # Using GPT-4o for best balance of quality and latency
            messages=messages,
            max_tokens=150,  # Keeping responses concise for lower latency
            temperature=0.7
        )
        return response.choices[0].message.content

    def repeat_last_response(self) -> str:
        """Repeat the last response for the student, or a starter line if none yet."""
        if self.last_response:
            return f"I'll repeat that. {self.last_response}"
        return "I don't have anything to repeat yet. Let's start our session."
class LiveKitVoiceAgent:
    """
    Real-time voice agent: joins a LiveKit room, feeds incoming microphone
    audio to a SpeechTherapistAgent, and speaks the replies back through a
    text-to-speech agent.
    """
    def __init__(self, room_url: str, token: str):
        """Store connection details and build the therapist; no I/O happens here."""
        self.room_url = room_url
        self.token = token
        self.room = None
        self.audio_source = None
        self.speech_therapist = SpeechTherapistAgent()
        self.tts_agent = None

    async def connect(self):
        """Join the LiveKit room, wire up audio in/out, greet, then stream forever."""
        try:
            self.room = rtc.Room()
            await self.room.connect(self.room_url, self.token)
            logger.info(f"Connected to room: {self.room.name}")
            # Incoming audio: the student's microphone
            self.audio_source = agents.AudioSource(self.room)
            # Outgoing audio: the therapist's synthesized voice
            self.tts_agent = agents.OpenAITTSAgent(
                api_key=os.getenv("OPENAI_API_KEY"),
                room=self.room,
                voice="alloy",  # Using a clear voice that works well for neurodivergent students
                model="tts-1",  # Using the fastest model for lower latency
            )
            await self._speak_initial_message()
            await self._process_audio_stream()
        except Exception as e:
            logger.error(f"Error connecting to LiveKit: {e}")
            raise

    async def _speak_initial_message(self):
        """Deliver the scripted greeting and remember it for repeat requests."""
        greeting = "Hello! I'm your speech therapist today. We'll be doing some simple speech exercises together. I'm here to help you practice. Let's start with an introduction. Can you tell me your name?"
        await self.tts_agent.speak(greeting)
        self.speech_therapist.last_response = greeting

    async def _process_audio_stream(self):
        """Consume audio frames continuously, batching ~2-second chunks for the AI."""
        pcm_buffer = bytearray()
        async for audio_frame in self.audio_source:
            pcm_buffer.extend(audio_frame.data)
            # Wait until roughly two seconds of audio have accumulated
            if len(pcm_buffer) < 32000:  # 16kHz * 2 seconds * 1 byte per sample
                continue
            chunk = bytes(pcm_buffer)
            pcm_buffer.clear()
            reply = await self.speech_therapist.process_audio(chunk)
            # Speak the reply if the therapist produced one
            if reply:
                await self.tts_agent.speak(reply)

    async def disconnect(self):
        """Leave the LiveKit room if currently connected."""
        if self.room:
            await self.room.disconnect()
            logger.info("Disconnected from room")
def create_gradio_interface():
    """Create the Gradio interface for the Speech Therapist application.

    Returns:
        A gr.Blocks interface with Start / Repeat / End controls bound to an
        internal LiveKitVoiceAgent session.
    """
    # Define CSS for better UI
    css = """
    .container {
    max-width: 800px;
    margin: auto;
    padding: 20px;
    }
    .title {
    text-align: center;
    margin-bottom: 20px;
    }
    .instructions {
    background-color: #f9f9f9;
    padding: 15px;
    border-radius: 10px;
    margin-bottom: 20px;
    }
    """
    # Session state shared by the handlers below
    agent = None   # active LiveKitVoiceAgent, or None
    task = None    # background asyncio task running agent.connect()

    async def start_session(room_url, token):
        """Start the LiveKit session."""
        nonlocal agent, task
        if not room_url or not token:
            return "Please provide both LiveKit room URL and token."
        # Guard: the original allowed a second Start to orphan (leak) the
        # previous agent and its still-running background task
        if agent is not None:
            return "A session is already running. End it before starting a new one."
        try:
            agent = LiveKitVoiceAgent(room_url, token)
            # connect() runs until the stream ends, so run it in the background
            task = asyncio.create_task(agent.connect())
            return "Session started! The AI speech therapist is now listening."
        except Exception as e:
            logger.error(f"Error starting session: {e}")
            return f"Error: {str(e)}"

    async def end_session():
        """End the LiveKit session."""
        nonlocal agent, task
        if agent:
            try:
                await agent.disconnect()
            finally:
                # Reset state even if disconnect() raised, so the UI can't get
                # stuck with a dead session (the original skipped cleanup on error)
                if task:
                    task.cancel()
                agent = None
                task = None
            return "Session ended."
        return "No active session to end."

    async def repeat_last():
        """Repeat the last response."""
        if agent and agent.speech_therapist:
            response = agent.speech_therapist.repeat_last_response()
            if agent.tts_agent:
                await agent.tts_agent.speak(response)
            return "Repeating last response."
        return "No active session or no response to repeat."

    # Create the Gradio interface
    with gr.Blocks(css=css) as interface:
        gr.Markdown("# AI Speech Therapist", elem_classes=["title"])
        with gr.Row(elem_classes=["instructions"]):
            gr.Markdown("""
            ## Instructions
            1. Enter your LiveKit room URL and token
            2. Click "Start Session" to begin
            3. Speak clearly into your microphone
            4. Use the "Repeat" button if you need the AI to repeat its last response
            5. Click "End Session" when finished
            This tool is designed for students with speech impediments, particularly those with ASD.
            """)
        with gr.Row():
            with gr.Column():
                room_url = gr.Textbox(label="LiveKit Room URL", placeholder="wss://your-livekit-instance.com")
                token = gr.Textbox(label="LiveKit Token", placeholder="Your LiveKit token")
        with gr.Row():
            start_btn = gr.Button("Start Session", variant="primary")
            repeat_btn = gr.Button("Repeat Last Response", variant="secondary")
            end_btn = gr.Button("End Session", variant="stop")
        status = gr.Textbox(label="Status", value="Enter LiveKit details and click Start Session")
        # Connect buttons to functions
        start_btn.click(start_session, inputs=[room_url, token], outputs=status)
        repeat_btn.click(repeat_last, outputs=status)
        end_btn.click(end_session, outputs=status)
    return interface
def main():
    """Build the Gradio interface and launch the web server."""
    interface = create_gradio_interface()
    # NOTE(review): launch(enable_queue=...) was deprecated and then removed
    # in Gradio 4.x; queue() on the Blocks object is the supported way to
    # handle multiple concurrent users.
    interface.queue()
    # Launch the Gradio app
    interface.launch(
        share=True,             # For testing - creates a public link
        server_name="0.0.0.0",  # Binds to all interfaces
        server_port=7860,       # Default Gradio port
    )
# Run the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()