|
|
|
|
|
|
|
|
import os |
|
|
import base64 |
|
|
import json |
|
|
import asyncio |
|
|
import numpy as np |
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
from audio_utils import ulaw_to_pcm16 |
|
|
from stt_handler import transcribe_audio_chunk |
|
|
from llm_handler import get_llm_response |
|
|
from tts_handler import text_to_speech_stream |
|
|
from tool_handler import execute_tool_call |
|
|
|
|
|
|
|
|
# Load environment variables (e.g. SYSTEM_PROMPT, service API keys) from a
# local .env file before anything reads os.getenv().
load_dotenv()


# ASGI application instance; served by uvicorn (see the __main__ block).
app = FastAPI()
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Health-check endpoint.

    Confirms the server is up and points callers at the WebSocket endpoint.
    This is the page shown when the Hugging Face Space URL is opened in a
    browser.
    """
    return {
        "status": "running",
        "message": "RentBot is active. Connect via WebSocket at the /rentbot endpoint.",
    }
|
|
|
|
|
|
|
|
|
|
|
# Seconds of caller audio to accumulate before triggering a processing pass.
# NOTE(review): despite the name, this is a fixed-size window, not real
# silence/VAD detection — the buffer is processed whenever it reaches
# AUDIO_BUFFER_SIZE samples, regardless of whether the caller paused.
SILENCE_THRESHOLD_SECONDS = 0.7

# Sample rate (Hz) assumed for the incoming audio stream.
AUDIO_RATE = 8000

# Processing threshold in samples: seconds * rate (0.7 * 8000 = 5600).
AUDIO_BUFFER_SIZE = int(SILENCE_THRESHOLD_SECONDS * AUDIO_RATE)


# Per-call session state, keyed by streamSid. Each entry holds:
#   "messages"        - conversation history passed to the LLM
#   "processing_task" - the in-flight STT->LLM->TTS asyncio.Task, or None
sessions = {}
|
|
|
|
|
|
|
|
|
|
|
@app.websocket("/rentbot")
async def websocket_endpoint(ws: WebSocket):
    """Handle one media-stream WebSocket session for a call.

    The peer sends JSON text frames with an 'event' field — 'start', 'media'
    (base64 u-law audio), 'mark', and 'stop' (Twilio Media Streams-style
    protocol, judging by the event shapes; confirm against the caller).
    Synthesized audio is sent back as 'media' frames keyed by streamSid.
    """
    await ws.accept()
    stream_sid = None
    # Incoming caller audio accumulated as PCM16 samples until it reaches
    # AUDIO_BUFFER_SIZE, at which point it is handed to the pipeline.
    audio_buffer = np.array([], dtype=np.int16)

    try:
        async for message in ws.iter_text():
            data = json.loads(message)

            if data['event'] == 'start':
                # New call: create per-stream session state keyed by streamSid.
                stream_sid = data['start']['streamSid']
                sessions[stream_sid] = {
                    # Conversation history, seeded with the system prompt.
                    "messages": [{"role": "system", "content": os.getenv("SYSTEM_PROMPT")}],
                    # Handle to the in-flight processing task, if any.
                    "processing_task": None
                }
                print(f"New stream started: {stream_sid}")

                # Greet immediately; the greeting is also recorded in history
                # so the LLM knows it has already spoken.
                initial_greeting = "Hi! I'm RentBot, your leasing assistant. How can I help you today?"
                sessions[stream_sid]["messages"].append({"role": "assistant", "content": initial_greeting})

                async def send_initial_greeting():
                    # Synthesize the greeting and stream it out as base64
                    # 'media' frames, then a 'mark' to signal end of turn.
                    tts_iterator = text_to_speech_stream(iter([initial_greeting]))
                    async for audio_chunk in tts_iterator:
                        payload = base64.b64encode(audio_chunk).decode('utf-8')
                        await ws.send_json({
                            "event": "media",
                            "streamSid": stream_sid,
                            "media": {"payload": payload}
                        })
                    await ws.send_json({"event": "mark", "streamSid": stream_sid, "mark": {"name": "bot_turn_end"}})

                # NOTE(review): fire-and-forget — no reference to this task is
                # retained, so it could be garbage-collected before it runs to
                # completion (see asyncio.create_task docs).
                asyncio.create_task(send_initial_greeting())

            elif data['event'] == 'media':
                # Ignore media that arrives before the 'start' handshake.
                if not stream_sid: continue
                # Decode base64 u-law payload and append the PCM16 samples.
                chunk_ulaw = base64.b64decode(data['media']['payload'])
                chunk_pcm = ulaw_to_pcm16(chunk_ulaw)
                audio_buffer = np.append(audio_buffer, chunk_pcm)

                if len(audio_buffer) >= AUDIO_BUFFER_SIZE:
                    # If the previous turn is still processing, keep buffering
                    # rather than starting an overlapping pipeline.
                    if sessions[stream_sid].get("processing_task") and not sessions[stream_sid]["processing_task"].done():
                        continue
                    task = asyncio.create_task(process_user_audio(ws, stream_sid, audio_buffer))
                    sessions[stream_sid]["processing_task"] = task
                    # Hand the current buffer to the task and start fresh.
                    audio_buffer = np.array([], dtype=np.int16)

            elif data['event'] == 'mark':
                if not stream_sid: continue
                # The peer echoes our 'mark' when the bot's audio has finished
                # playing; flush any leftover audio that never reached the full
                # threshold (>1000 samples = >125 ms at 8 kHz).
                if len(audio_buffer) > 1000:
                    if not (sessions[stream_sid].get("processing_task") and not sessions[stream_sid]["processing_task"].done()):
                        task = asyncio.create_task(process_user_audio(ws, stream_sid, audio_buffer))
                        sessions[stream_sid]["processing_task"] = task
                        audio_buffer = np.array([], dtype=np.int16)

            elif data['event'] == 'stop':
                print(f"Stream stopped: {stream_sid}")
                break

    except WebSocketDisconnect:
        print(f"WebSocket disconnected for stream {stream_sid}")
    except Exception as e:
        # Broad catch at the connection boundary so one bad frame cannot crash
        # the server; the session is still cleaned up in finally.
        print(f"An error occurred in websocket_endpoint: {e}")
    finally:
        # Cancel any in-flight pipeline and drop the per-stream session state.
        if stream_sid and stream_sid in sessions:
            if sessions[stream_sid].get("processing_task"):
                sessions[stream_sid]["processing_task"].cancel()
            del sessions[stream_sid]
            print(f"Session cleaned up for stream {stream_sid}")
|
|
|
|
|
|
|
|
|
|
|
async def _run_llm_tts_round(ws: WebSocket, stream_sid: str):
    """Run one streaming LLM round, piping text chunks to TTS as they arrive.

    Builds the queue that bridges the LLM's chunk callback to the async text
    iterator consumed by TTS, runs both sides concurrently, and returns the
    (assistant_message, tool_calls) pair produced by get_llm_response.
    """
    tts_queue = asyncio.Queue()

    async def on_llm_chunk(chunk):
        # Forward each streamed LLM text chunk to the TTS side.
        await tts_queue.put(chunk)

    async def text_chunks():
        # Drain the queue until the None sentinel marks end-of-response.
        while True:
            chunk = await tts_queue.get()
            if chunk is None:
                break
            yield chunk

    llm_task = asyncio.create_task(
        get_llm_response(sessions[stream_sid]["messages"], on_llm_chunk)
    )
    tts_task = asyncio.create_task(
        stream_and_send_audio(ws, stream_sid, text_chunks())
    )

    result = await llm_task
    await tts_queue.put(None)  # unblock text_chunks() so TTS can finish
    await tts_task
    return result


async def process_user_audio(ws: WebSocket, stream_sid: str, audio_chunk: np.ndarray):
    """The main logic loop: STT -> LLM -> (Tool/TTS).

    Transcribes one buffered audio window, appends the user turn to the
    session history, streams the LLM reply through TTS, and — if the LLM
    requested tool calls — executes them and runs a second LLM/TTS round to
    phrase the tool results for the caller.
    """
    print(f"[{stream_sid}] Processing audio chunk of size {len(audio_chunk)}...")

    user_text = await transcribe_audio_chunk(audio_chunk)
    if not user_text:
        # Nothing intelligible in this window; skip the turn entirely.
        print(f"[{stream_sid}] No text transcribed.")
        return

    print(f"[{stream_sid}] User said: {user_text}")
    sessions[stream_sid]["messages"].append({"role": "user", "content": user_text})

    # First round: stream the assistant's reply (text and/or tool calls).
    assistant_message, tool_calls = await _run_llm_tts_round(ws, stream_sid)

    # Record the assistant turn exactly once. (The previous version appended
    # it twice when the reply carried both text content and tool calls,
    # duplicating the message in history.) When tool calls are present the
    # message must precede the tool-result messages in the history.
    if assistant_message and (assistant_message.get("content") or tool_calls):
        sessions[stream_sid]["messages"].append(assistant_message)

    if tool_calls:
        for tool_call_data in tool_calls:
            # Adapt the plain dict into the attribute-style object that
            # execute_tool_call expects (obj.id, obj.function.name, ...).
            tool_call = type('ToolCall', (), {
                'id': tool_call_data.get('id'),
                'function': type('Function', (), tool_call_data.get('function'))
            })()

            print(f"[{stream_sid}] Executing tool: {tool_call.function.name}")
            tool_result_message = execute_tool_call(tool_call)
            sessions[stream_sid]["messages"].append(tool_result_message)

        # Second round: let the LLM phrase the tool results for the caller.
        final_assistant_message, _ = await _run_llm_tts_round(ws, stream_sid)
        if final_assistant_message:
            sessions[stream_sid]["messages"].append(final_assistant_message)
|
|
|
|
|
|
|
|
async def stream_and_send_audio(ws: WebSocket, stream_sid: str, text_iterator):
    """Pipe a text stream through TTS and relay the audio to the caller.

    Each synthesized chunk is base64-encoded and sent as a 'media' frame;
    a final 'mark' frame signals that the bot's spoken turn is complete.
    """
    async for chunk in text_to_speech_stream(text_iterator):
        if not chunk:
            # Skip empty chunks the TTS stream may emit.
            continue
        encoded = base64.b64encode(chunk).decode('utf-8')
        frame = {
            "event": "media",
            "streamSid": stream_sid,
            "media": {"payload": encoded},
        }
        await ws.send_json(frame)

    # End-of-turn marker; the peer echoes it back when playback finishes.
    await ws.send_json({"event": "mark", "streamSid": stream_sid, "mark": {"name": "bot_turn_end"}})
    print(f"[{stream_sid}] Finished sending bot's audio turn.")
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Hosting platforms (e.g. Hugging Face Spaces) inject the listen port via
    # $PORT; fall back to 7860 for local runs.
    listen_port = int(os.environ.get("PORT", 7860))
    print(f"Starting RentBot server on host 0.0.0.0 and port {listen_port}...")
    uvicorn.run(app, host="0.0.0.0", port=listen_port)