Spaces:
Build error
Build error
Commit
·
e67faa0
1
Parent(s):
decb45c
Realtime Flow
Browse files- src/agent_session/main.py +17 -83
src/agent_session/main.py
CHANGED
|
@@ -1,11 +1,6 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
Agent Session for Avurna Flow (Final Corrected Version)
|
| 4 |
-
|
| 5 |
-
This script sets up a FastAPI web server to manage a voice-based AI agent
|
| 6 |
-
that connects to a LiveKit room. The core issue of the TypeError is resolved
|
| 7 |
-
by implementing the LLMStateWrapper, which correctly handles the asynchronous
|
| 8 |
-
generator returned by the LLM's chat method.
|
| 9 |
"""
|
| 10 |
import os
|
| 11 |
import json
|
|
@@ -14,93 +9,62 @@ from fastapi import FastAPI, BackgroundTasks
|
|
| 14 |
from pydantic import BaseModel
|
| 15 |
import uvicorn
|
| 16 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 17 |
|
| 18 |
-
# Import LiveKit and plugin components
|
| 19 |
from livekit.rtc import Room
|
| 20 |
from livekit.agents import Agent, AgentSession
|
| 21 |
from livekit.agents.llm import LLM
|
| 22 |
from livekit.agents.stt.stream_adapter import StreamAdapter
|
| 23 |
from livekit.plugins.google import LLM as GoogleLLM
|
| 24 |
from livekit.plugins.groq import STT
|
| 25 |
-
from livekit.plugins.hume import TTS, VoiceByName
|
| 26 |
from livekit.plugins.silero import VAD
|
| 27 |
|
| 28 |
from src.agent_session.constants import SYSTEM_PROMPT, GREETING_INSTRUCTIONS
|
| 29 |
from src.utils import validate_env_vars
|
| 30 |
|
| 31 |
-
# --- FastAPI Application Setup ---
|
| 32 |
app = FastAPI()
|
| 33 |
-
|
| 34 |
-
# Configure CORS (Cross-Origin Resource Sharing) to allow all origins
|
| 35 |
origins = ["*"]
|
| 36 |
-
app.add_middleware(
|
| 37 |
-
|
| 38 |
-
allow_origins=origins,
|
| 39 |
-
allow_credentials=True,
|
| 40 |
-
allow_methods=["*"],
|
| 41 |
-
allow_headers=["*"]
|
| 42 |
-
)
|
| 43 |
-
|
| 44 |
-
# --- Pydantic Model for API Request ---
|
| 45 |
class JoinRoomRequest(BaseModel):
|
| 46 |
-
"""Defines the expected data structure for a /join-room request."""
|
| 47 |
room_name: str
|
| 48 |
agent_token: str
|
| 49 |
|
| 50 |
-
# --- Custom Agent Definition ---
|
| 51 |
class VoiceAssistant(Agent):
|
| 52 |
-
"""A simple voice assistant agent with a predefined system prompt."""
|
| 53 |
def __init__(self):
|
| 54 |
super().__init__(instructions=SYSTEM_PROMPT)
|
| 55 |
|
| 56 |
-
# --- Utility Function for State Publishing ---
|
| 57 |
async def send_agent_state(room: Room, state: str):
|
| 58 |
-
"""Publishes the agent's current state to the room via data channel."""
|
| 59 |
try:
|
| 60 |
-
# The message is structured as a JSON object for easy parsing by clients
|
| 61 |
msg = json.dumps({"type": "agent_state", "state": state})
|
| 62 |
await room.local_participant.publish_data(msg)
|
| 63 |
print(f"DEBUG: Sent agent state: {state}")
|
| 64 |
except Exception as e:
|
| 65 |
print(f"DEBUG: Error publishing agent state: {e}")
|
| 66 |
|
| 67 |
-
# --- THE DEFINITIVE FIX: LLM Wrapper ---
|
| 68 |
class LLMStateWrapper(LLM):
|
| 69 |
-
"""
|
| 70 |
-
Wraps an LLM instance to correctly handle its async generator `chat` method
|
| 71 |
-
and to inject agent state updates ("thinking", "listening") into the process.
|
| 72 |
-
This class solves the `TypeError: 'async_generator' object does not support
|
| 73 |
-
the asynchronous context manager protocol`.
|
| 74 |
-
"""
|
| 75 |
def __init__(self, llm: LLM, room: Room):
|
| 76 |
super().__init__()
|
| 77 |
self._llm = llm
|
| 78 |
self._room = room
|
| 79 |
|
|
|
|
|
|
|
| 80 |
async def chat(self, **kwargs):
|
| 81 |
-
|
| 82 |
-
This method is called by the AgentSession. It intercepts the call to the LLM,
|
| 83 |
-
sends the 'thinking' state, properly iterates over the LLM's async generator,
|
| 84 |
-
and then sends the 'listening' state upon completion.
|
| 85 |
-
"""
|
| 86 |
await send_agent_state(self._room, "thinking")
|
| 87 |
-
|
| 88 |
-
# Extract 'history' and pass all arguments along to the underlying LLM.
|
| 89 |
-
# This makes the wrapper resilient to future library updates.
|
| 90 |
-
history = kwargs.pop('history', [])
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
-
# --- Main Agent Session Logic ---
|
| 99 |
async def run_agent_session(room_name: str, agent_token: str):
|
| 100 |
-
"""
|
| 101 |
-
This function contains the main logic for connecting the agent to a
|
| 102 |
-
LiveKit room and running the session.
|
| 103 |
-
"""
|
| 104 |
livekit_url = os.getenv("LIVEKIT_URL")
|
| 105 |
room = Room()
|
| 106 |
|
|
@@ -109,35 +73,23 @@ async def run_agent_session(room_name: str, agent_token: str):
|
|
| 109 |
await room.connect(livekit_url, agent_token)
|
| 110 |
print("DEBUG: 2. Connection successful.")
|
| 111 |
|
| 112 |
-
# Agent is ready to receive input
|
| 113 |
await send_agent_state(room, "listening")
|
| 114 |
|
| 115 |
print("DEBUG: 3. Initializing plugins...")
|
| 116 |
-
|
| 117 |
-
# Wrap the Google LLM with our custom state-aware wrapper
|
| 118 |
llm_wrapper = LLMStateWrapper(llm=GoogleLLM(model="gemini-1.5-flash"), room=room)
|
| 119 |
-
|
| 120 |
-
# Configure Voice Activity Detection (VAD)
|
| 121 |
vad = VAD.load(min_speech_duration=0.1)
|
| 122 |
-
|
| 123 |
-
# Configure Speech-to-Text (STT) with the VAD adapter
|
| 124 |
stt = StreamAdapter(stt=STT(model="whisper-large-v3-turbo"), vad=vad)
|
| 125 |
-
|
| 126 |
-
# Configure Text-to-Speech (TTS)
|
| 127 |
tts = TTS(voice=VoiceByName(name="Tiktok Fashion Influencer"), instant_mode=True)
|
| 128 |
-
|
| 129 |
print("DEBUG: 4. Plugins initialized.")
|
| 130 |
|
| 131 |
print("DEBUG: 5. Creating AgentSession...")
|
| 132 |
session = AgentSession(vad=vad, stt=stt, llm=llm_wrapper, tts=tts)
|
| 133 |
print("DEBUG: 6. AgentSession created.")
|
| 134 |
|
| 135 |
-
print("DEBUG: 7. Starting session
|
| 136 |
-
# The .start() method begins the main processing loop for the agent
|
| 137 |
await session.start(agent=VoiceAssistant(), room=room)
|
| 138 |
|
| 139 |
print("DEBUG: 8. Session started. Generating initial greeting...")
|
| 140 |
-
# Proactively generate a greeting to start the conversation
|
| 141 |
await send_agent_state(room, "speaking")
|
| 142 |
await session.generate_reply(instructions=GREETING_INSTRUCTIONS)
|
| 143 |
|
|
@@ -147,37 +99,19 @@ async def run_agent_session(room_name: str, agent_token: str):
|
|
| 147 |
print(f"FATAL ERROR in agent session: {e}")
|
| 148 |
print(traceback.format_exc())
|
| 149 |
finally:
|
| 150 |
-
# Ensure cleanup and disconnection on exit or error
|
| 151 |
print(f"DEBUG: Agent session for room {room_name} is ending. Cleaning up.")
|
| 152 |
await room.disconnect()
|
| 153 |
|
| 154 |
-
# --- FastAPI API Endpoints ---
|
| 155 |
@app.post("/join-room")
|
| 156 |
async def join_room(req: JoinRoomRequest, background_tasks: BackgroundTasks):
|
| 157 |
-
"""
|
| 158 |
-
API endpoint to trigger an agent to join a room.
|
| 159 |
-
It runs the agent session as a background task to not block the HTTP response.
|
| 160 |
-
"""
|
| 161 |
print(f"DEBUG: Received POST request to /join-room for: {req.room_name}")
|
| 162 |
background_tasks.add_task(run_agent_session, req.room_name, req.agent_token)
|
| 163 |
return {"status": "agent_triggered"}
|
| 164 |
|
| 165 |
@app.get("/")
|
| 166 |
async def root():
|
| 167 |
-
"""A simple health check endpoint."""
|
| 168 |
return {"status": "avurna_agent_server_online"}
|
| 169 |
|
| 170 |
-
# --- Main Execution Block ---
|
| 171 |
if __name__ == "__main__":
|
| 172 |
-
|
| 173 |
-
validate_env_vars([
|
| 174 |
-
"HUME_API_KEY",
|
| 175 |
-
"LIVEKIT_URL",
|
| 176 |
-
"LIVEKIT_API_KEY",
|
| 177 |
-
"LIVEKIT_API_SECRET",
|
| 178 |
-
"GROQ_API_KEY",
|
| 179 |
-
"GOOGLE_API_KEY"
|
| 180 |
-
])
|
| 181 |
-
|
| 182 |
-
# Run the FastAPI server using Uvicorn
|
| 183 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
Agent Session for Avurna Flow (Final Corrected Version)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
import os
|
| 6 |
import json
|
|
|
|
| 9 |
from pydantic import BaseModel
|
| 10 |
import uvicorn
|
| 11 |
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
+
from contextlib import asynccontextmanager # --- KEY: Import the required decorator ---
|
| 13 |
|
|
|
|
| 14 |
from livekit.rtc import Room
|
| 15 |
from livekit.agents import Agent, AgentSession
|
| 16 |
from livekit.agents.llm import LLM
|
| 17 |
from livekit.agents.stt.stream_adapter import StreamAdapter
|
| 18 |
from livekit.plugins.google import LLM as GoogleLLM
|
| 19 |
from livekit.plugins.groq import STT
|
| 20 |
+
from livekit.plugins.hume import TTS, VoiceByName, VoiceProvider
|
| 21 |
from livekit.plugins.silero import VAD
|
| 22 |
|
| 23 |
from src.agent_session.constants import SYSTEM_PROMPT, GREETING_INSTRUCTIONS
|
| 24 |
from src.utils import validate_env_vars
|
| 25 |
|
|
|
|
| 26 |
# --- FastAPI application setup ---
app = FastAPI()

# CORS: allow any origin so browser clients served from other hosts can call
# this API directly.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests (the wildcard cannot be
# echoed back with credentials) — confirm whether credentials are needed.
origins = ["*"]
app.add_middleware(CORSMiddleware, allow_origins=origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
|
| 29 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
class JoinRoomRequest(BaseModel):
    """Request body for the POST /join-room endpoint."""

    # Name of the LiveKit room the agent should join.
    room_name: str
    # Access token authorizing the agent to connect to that room.
    agent_token: str
| 33 |
|
|
|
|
| 34 |
class VoiceAssistant(Agent):
    """Voice assistant agent configured with the project's system prompt."""

    def __init__(self):
        # All behavior is inherited from Agent; only the instructions differ.
        super().__init__(instructions=SYSTEM_PROMPT)
| 37 |
|
|
|
|
| 38 |
async def send_agent_state(room: Room, state: str):
    """Publish the agent's current state to the room's data channel.

    The state (e.g. "thinking", "listening", "speaking") is serialized as a
    small JSON message so clients can parse it easily. Failures are logged
    and swallowed deliberately: a missed state update must never take down
    the agent session.
    """
    try:
        payload = json.dumps({"type": "agent_state", "state": state})
        await room.local_participant.publish_data(payload)
        print(f"DEBUG: Sent agent state: {state}")
    except Exception as e:
        print(f"DEBUG: Error publishing agent state: {e}")
| 45 |
|
|
|
|
| 46 |
class LLMStateWrapper(LLM):
    """Wraps an LLM so that agent-state updates bracket each chat turn.

    ``chat`` publishes "thinking" before delegating to the wrapped LLM and
    "listening" once the caller's ``async with`` context exits (i.e. after
    the response stream has been fully consumed downstream).
    """

    def __init__(self, llm: LLM, room: Room):
        # llm:  the real LLM implementation to delegate to.
        # room: LiveKit room used for publishing state messages.
        super().__init__()
        self._llm = llm
        self._room = room

    # AgentSession consumes the LLM via `async with llm.chat(...)`;
    # @asynccontextmanager turns this generator into a valid async context
    # manager so that protocol is satisfied (this is the commit's fix for
    # the earlier 'async_generator does not support the asynchronous context
    # manager protocol' TypeError).
    @asynccontextmanager
    async def chat(self, **kwargs):
        # Entered: tell clients the agent is processing their input.
        await send_agent_state(self._room, "thinking")
        try:
            # Delegate to the wrapped LLM and hand its stream to the caller.
            # NOTE(review): assumes self._llm.chat(**kwargs) itself returns an
            # async context manager yielding the token stream — confirm
            # against the installed livekit-agents version.
            async with self._llm.chat(**kwargs) as stream:
                yield stream
        finally:
            # Exited (normally or via error/cancellation): back to listening.
            await send_agent_state(self._room, "listening")
| 66 |
|
|
|
|
| 67 |
async def run_agent_session(room_name: str, agent_token: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
livekit_url = os.getenv("LIVEKIT_URL")
|
| 69 |
room = Room()
|
| 70 |
|
|
|
|
| 73 |
await room.connect(livekit_url, agent_token)
|
| 74 |
print("DEBUG: 2. Connection successful.")
|
| 75 |
|
|
|
|
| 76 |
await send_agent_state(room, "listening")
|
| 77 |
|
| 78 |
print("DEBUG: 3. Initializing plugins...")
|
|
|
|
|
|
|
| 79 |
llm_wrapper = LLMStateWrapper(llm=GoogleLLM(model="gemini-1.5-flash"), room=room)
|
|
|
|
|
|
|
| 80 |
vad = VAD.load(min_speech_duration=0.1)
|
|
|
|
|
|
|
| 81 |
stt = StreamAdapter(stt=STT(model="whisper-large-v3-turbo"), vad=vad)
|
|
|
|
|
|
|
| 82 |
tts = TTS(voice=VoiceByName(name="Tiktok Fashion Influencer"), instant_mode=True)
|
|
|
|
| 83 |
print("DEBUG: 4. Plugins initialized.")
|
| 84 |
|
| 85 |
print("DEBUG: 5. Creating AgentSession...")
|
| 86 |
session = AgentSession(vad=vad, stt=stt, llm=llm_wrapper, tts=tts)
|
| 87 |
print("DEBUG: 6. AgentSession created.")
|
| 88 |
|
| 89 |
+
print("DEBUG: 7. Starting session...")
|
|
|
|
| 90 |
await session.start(agent=VoiceAssistant(), room=room)
|
| 91 |
|
| 92 |
print("DEBUG: 8. Session started. Generating initial greeting...")
|
|
|
|
| 93 |
await send_agent_state(room, "speaking")
|
| 94 |
await session.generate_reply(instructions=GREETING_INSTRUCTIONS)
|
| 95 |
|
|
|
|
| 99 |
print(f"FATAL ERROR in agent session: {e}")
|
| 100 |
print(traceback.format_exc())
|
| 101 |
finally:
|
|
|
|
| 102 |
print(f"DEBUG: Agent session for room {room_name} is ending. Cleaning up.")
|
| 103 |
await room.disconnect()
|
| 104 |
|
|
|
|
| 105 |
@app.post("/join-room")
|
| 106 |
async def join_room(req: JoinRoomRequest, background_tasks: BackgroundTasks):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
print(f"DEBUG: Received POST request to /join-room for: {req.room_name}")
|
| 108 |
background_tasks.add_task(run_agent_session, req.room_name, req.agent_token)
|
| 109 |
return {"status": "agent_triggered"}
|
| 110 |
|
| 111 |
@app.get("/")
|
| 112 |
async def root():
|
|
|
|
| 113 |
return {"status": "avurna_agent_server_online"}
|
| 114 |
|
|
|
|
| 115 |
if __name__ == "__main__":
|
| 116 |
+
validate_env_vars(["HUME_API_KEY", "LIVEKIT_URL", "LIVEKIT_API_KEY", "LIVEKIT_API_SECRET", "GROQ_API_KEY", "GOOGLE_API_KEY"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|