# lifedebugger's picture
# Deploy files from GitHub repository
# 7ae0310
import fastapi
from fastapi.middleware.cors import CORSMiddleware
from fastrtc import ReplyOnPause, Stream, AlgoOptions, SileroVadOptions, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials
from src.utils.audio_helper import audio_to_bytes, resample_audio
from dotenv import load_dotenv
from src.internal.rag.chat_template import get_chat_template
from src.internal.tts.base_tts import TTS
from src.internal.stt.base_stt import STT
from src.internal.agents import Agent, AgentRequest
import logging
import time
import platform
import socket
import os
import numpy as np
import asyncio
import asyncio
from src.config.constant import HF_TOKEN
import re
# Load variables from a local .env file into the process environment.
# NOTE(review): this runs after `src.config.constant` (HF_TOKEN) has already been
# imported above; if that module reads the environment at import time without
# loading .env itself, HF_TOKEN will not see .env values -- confirm.
load_dotenv()
# Configure the root logger at INFO; the rest of this module logs through
# the root `logging` functions directly.
logging.basicConfig(level=logging.INFO)
class RTCHandler:
    """Voice pipeline that wires STT -> agent (LLM) -> TTS into a FastRTC stream.

    The handler owns the FastRTC ``Stream`` and the FastAPI app it is mounted
    on, plus a small amount of conversation state (``messages`` and
    ``full_response``).

    NOTE(review): conversation state lives on the handler instance, so every
    connection served by the same handler shares one history -- confirm that a
    single concurrent caller is the intended deployment.
    """

    # Sample rate attached to every audio frame yielded to FastRTC.
    # NOTE(review): presumably what ``resample_audio`` outputs -- confirm.
    OUTPUT_SAMPLE_RATE = 24000
    # Number of samples per yielded audio frame.
    AUDIO_CHUNK_SIZE = 1024
    # Once the history holds this many entries it is cleared before appending.
    MAX_HISTORY_MESSAGES = 15
    # A streamed text chunk containing any of these marks triggers a TTS flush.
    _CLAUSE_BREAK = re.compile(r'[.,?;!]')

    def __init__(self, agent: "Agent", stt: "STT", tts: "TTS"):
        """Store the collaborators and initialise conversation state.

        Args:
            agent: Object whose ``get_result(request)`` is async-iterated for
                streamed response events.
            stt: Object whose ``transcribe(bytes)`` returns the transcription.
            tts: Object whose awaitable ``generate_audio_buffer(text)`` returns
                a sequence whose first element is the audio buffer.
        """
        self.agent = agent
        self.stt = stt
        self.tts = tts
        self.full_response = ""
        # Previously never initialised, which made reset_conversation(),
        # /status and get_conversation_history() raise AttributeError.
        self.sys_prompt = ""
        self.messages = self._initial_messages()
        self.stream = None
        self.app = None
        self._setup_webrtc_ip()

    def _initial_messages(self):
        """Return a fresh history: the system prompt if one is set, else empty."""
        if self.sys_prompt:
            return [{"role": "system", "content": self.sys_prompt}]
        return []

    def _setup_webrtc_ip(self):
        """Setup WebRTC IP for Windows.

        On Windows, discovers the local outbound IP via a UDP socket and
        exports it as the ``WEBRTC_IP`` environment variable; falls back to
        ``127.0.0.1`` when discovery fails. Does nothing on other platforms.
        """
        if platform.system() != 'Windows':
            return
        try:
            # The context manager guarantees the socket is closed.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
                probe.connect(('8.8.8.8', 80))
                local_ip = probe.getsockname()[0]
        except OSError:
            local_ip = '127.0.0.1'
        os.environ['WEBRTC_IP'] = local_ip

    async def _synthesize(self, text):
        """Run TTS on ``text`` and return a list of ``(sample_rate, samples)`` frames."""
        audio_buffer_gen = await self.tts.generate_audio_buffer(text)
        resampled = resample_audio(audio_buffer_gen[0])
        step = self.AUDIO_CHUNK_SIZE
        return [
            (self.OUTPUT_SAMPLE_RATE, resampled[start:start + step])
            for start in range(0, len(resampled), step)
        ]

    async def _stream_text_to_audio(self, request):
        """Async-generate audio frames for the agent's streamed answer.

        Text chunks are accumulated (and appended to ``self.full_response``)
        until a chunk contains clause punctuation, at which point the buffered
        text is synthesised and emitted. Whatever text is still buffered when
        the stream ends is synthesised too, so a reply without trailing
        punctuation is still spoken.
        """
        text_buffer = ""
        async for stream_data in self.agent.get_result(request):
            event_type = stream_data["type"]
            if event_type == "chunk":
                chunk = stream_data["data"]["chunk"]
                self.full_response += chunk
                text_buffer += chunk
                if not self._CLAUSE_BREAK.search(chunk):
                    continue
                try:
                    frames = await self._synthesize(text_buffer)
                except Exception as e:
                    # Best effort: keep the buffered text so it is retried
                    # together with the next clause.
                    logging.error(f"TTS generation failed for chunk: {e}")
                    continue
                text_buffer = ""
                for frame in frames:
                    yield frame
            elif event_type == "metadata":
                setup_time = stream_data['data']['setup_time']
                logging.info(f"Setup completed in {setup_time:.2f}s")
            elif event_type == "complete":
                total_time = stream_data['data']['total_time']
                logging.info(f"Total time: {total_time:.2f}s")
                break

        # Flush the tail that never hit a punctuation mark.
        if text_buffer.strip():
            try:
                frames = await self._synthesize(text_buffer)
            except Exception as e:
                logging.error(f"TTS generation failed for chunk: {e}")
            else:
                for frame in frames:
                    yield frame

    def echo(self, audio):
        """Reply handler: transcribe ``audio``, query the agent, stream speech back.

        This is a synchronous generator (it is what ``create_stream`` hands to
        ``ReplyOnPause``); the async agent/TTS pipeline is driven step by step
        on a private event loop.

        Args:
            audio: The captured utterance; forwarded unchanged to
                ``audio_to_bytes``.

        Yields:
            ``(sample_rate, samples)`` tuples. If the transcription is empty
            nothing is yielded. On any error one second of silence is yielded
            instead of raising.
        """
        try:
            stt_started = time.time()
            logging.info("Performing STT")
            transcription = self.stt.transcribe(audio_to_bytes(audio))
            prompt = transcription
            # Treat None and whitespace-only results as "nothing was said".
            if not prompt or not prompt.strip():
                logging.info("STT returned empty string")
                return
            logging.info(f"STT response: {transcription}")
            logging.info(f"STT took {time.time() - stt_started} seconds")

            llm_started = time.time()
            self.full_response = ""
            router_agent_request = AgentRequest(
                # Snapshot of the persisted history; previously a fresh empty
                # list was sent on every call, so the agent never saw context.
                chat_memory=list(self.messages),
                prompt_template=get_chat_template("customer_service"),
                question=prompt,
            )

            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            audio_stream = self._stream_text_to_audio(router_agent_request)
            try:
                while True:
                    try:
                        frame = loop.run_until_complete(audio_stream.__anext__())
                    except StopAsyncIteration:
                        break
                    yield frame
            finally:
                # Finalise the async generator on its own loop before closing
                # the loop; matters when the consumer stops iterating early.
                try:
                    loop.run_until_complete(audio_stream.aclose())
                except Exception as e:
                    logging.error(f"Error closing audio stream: {e}")
                finally:
                    loop.close()

            if len(self.messages) >= self.MAX_HISTORY_MESSAGES:
                self.messages = self._initial_messages()
            # NOTE(review): only the assistant turn is recorded, as before;
            # confirm whether the agent also expects user turns in chat_memory.
            self.messages.append(
                {"role": "assistant", "content": self.full_response + " "}
            )
            logging.info(f"LLM response: {self.full_response}")
            logging.info(f"LLM took {time.time() - llm_started} seconds")
        except Exception as e:
            logging.error(f"Error in echo function: {e}")
            error_audio = np.zeros(self.OUTPUT_SAMPLE_RATE, dtype=np.float32)
            yield (self.OUTPUT_SAMPLE_RATE, error_audio)

    def reset_conversation(self):
        """Clear the conversation history and the last response."""
        logging.info("Resetting chat")
        self.messages = self._initial_messages()
        self.full_response = ""

    def create_stream(self):
        """Build the FastRTC ``Stream`` around ``echo`` and store it on ``self.stream``.

        Returns:
            The created stream.

        Raises:
            Exception: Any construction error is logged and re-raised.
        """
        try:
            async def get_credentials():
                return await get_cloudflare_turn_credentials_async(hf_token=HF_TOKEN)

            self.stream = Stream(
                rtc_configuration=get_credentials,
                server_rtc_configuration=get_cloudflare_turn_credentials(ttl=360_000),
                handler=ReplyOnPause(
                    self.echo,
                    algo_options=AlgoOptions(
                        audio_chunk_duration=0.5,
                        started_talking_threshold=0.1,
                        speech_threshold=0.03,
                    ),
                    model_options=SileroVadOptions(
                        threshold=0.90,
                        min_speech_duration_ms=250,
                        min_silence_duration_ms=2000,
                        speech_pad_ms=400,
                        max_speech_duration_s=15,
                    ),
                ),
                modality="audio",
                mode="send-receive",
                ui_args={"title": "Sakura A.I Customer Service"},
            )
            return self.stream
        except Exception as e:
            logging.error(f"Error creating stream: {e}")
            raise

    def create_fastapi_app(self):
        """Create the FastAPI app, mount the stream, and register helper routes.

        Registers ``GET /reset`` (clears the conversation) and ``GET /status``
        (reports history size and last response). The stream is created first
        if it does not exist yet.

        Returns:
            The FastAPI application, also stored on ``self.app``.

        Raises:
            Exception: Any setup error is logged and re-raised.
        """
        try:
            self.app = fastapi.FastAPI()
            # NOTE(review): wildcard origins together with credentials is very
            # permissive; tighten if this is exposed beyond a demo.
            self.app.add_middleware(
                CORSMiddleware,
                allow_origins=["*"],
                allow_credentials=True,
                allow_methods=["*"],
                allow_headers=["*"],
            )
            if not self.stream:
                self.create_stream()
            self.stream.mount(self.app)

            @self.app.get("/reset")
            async def reset():
                try:
                    self.reset_conversation()
                    return {"status": "success"}
                except Exception as e:
                    logging.error(f"Error in reset endpoint: {e}")
                    return {"status": "error", "message": str(e)}

            @self.app.get("/status")
            async def status():
                try:
                    return {
                        "status": "running",
                        "messages_count": len(self.messages),
                        "last_response": self.full_response,
                    }
                except Exception as e:
                    logging.error(f"Error in status endpoint: {e}")
                    return {"status": "error", "message": str(e)}

            return self.app
        except Exception as e:
            logging.error(f"Error creating FastAPI app: {e}")
            raise

    def start_server(self, host: str = "0.0.0.0", port: int = 7862):
        """Serve the FastAPI app with uvicorn (blocking).

        Args:
            host: Interface to bind.
            port: TCP port to listen on.

        Raises:
            Exception: Any server error is logged and re-raised.
        """
        import uvicorn

        if not self.app:
            self.create_fastapi_app()
        logging.info(f"Starting server on {host}:{port}")
        try:
            uvicorn.run(self.app, host=host, port=port, log_level="info")
        except Exception as e:
            logging.error(f"Error starting server: {e}")
            raise

    def launch_ui(self, browser: bool = True, port=7860):
        """Launch the stream's built-in UI on all interfaces with sharing enabled.

        Args:
            browser: Currently unused.
            port: Port for the UI server.

        Raises:
            Exception: Any launch error is logged and re-raised.
        """
        try:
            if not self.stream:
                self.create_stream()
            if not self.app:
                self.create_fastapi_app()
            logging.info("Launching RTC UI...")
            # NOTE(review): `browser` is never forwarded, and `self.app` is
            # passed as launch()'s first positional argument -- confirm against
            # the launch() signature whether `inbrowser=browser` was intended.
            self.stream.ui.launch(
                self.app,
                server_name="0.0.0.0",
                server_port=port,
                share=True,
            )
        except Exception as e:
            logging.error(f"Error launching UI: {e}")
            raise

    def get_conversation_history(self):
        """Return a shallow copy of the conversation history."""
        return list(self.messages)

    def get_last_response(self):
        """Return the text of the most recent (possibly in-progress) response."""
        return self.full_response