import asyncio
import logging
import os
import platform
import re
import socket
import time

import fastapi
import numpy as np
from dotenv import load_dotenv
from fastapi.middleware.cors import CORSMiddleware
from fastrtc import (
    AlgoOptions,
    ReplyOnPause,
    SileroVadOptions,
    Stream,
    get_cloudflare_turn_credentials,
    get_cloudflare_turn_credentials_async,
)

from src.utils.audio_helper import audio_to_bytes, resample_audio
from src.internal.rag.chat_template import get_chat_template
from src.internal.tts.base_tts import TTS
from src.internal.stt.base_stt import STT
from src.internal.agents import Agent, AgentRequest
from src.config.constant import HF_TOKEN

load_dotenv()
logging.basicConfig(level=logging.INFO)

# Sample rate tagged on every audio frame yielded back to fastrtc.
# NOTE(review): assumes resample_audio() outputs 24 kHz audio — confirm.
OUTPUT_SAMPLE_RATE = 24000
# Number of samples per yielded audio frame.
AUDIO_CHUNK_SIZE = 1024
# A streamed LLM chunk containing one of these triggers speech synthesis of
# the text buffered so far.
_SPEAKABLE_BOUNDARY = re.compile(r"[.,?;!]")


class RTCHandler:
    """Voice pipeline served over WebRTC: STT -> agent (LLM) -> streamed TTS.

    Each user utterance detected by fastrtc's ``ReplyOnPause`` goes through
    ``echo``. That method transcribes the audio and asks the agent for a reply.
    Speech is synthesized segment by segment as the reply streams in.
    Completed turns are kept in ``self.messages`` and passed to the agent as
    chat memory on the next turn.
    """

    def __init__(self, agent: Agent, stt: STT, tts: TTS, max_history: int = 15):
        """Store the pipeline components and prepare (but not start) serving.

        Args:
            agent: produces the streamed reply via ``get_result``.
            stt: speech-to-text engine used on incoming audio.
            tts: text-to-speech engine used on the reply.
            max_history: maximum number of chat-memory entries (user and
                assistant messages) retained between turns; oldest are dropped.
        """
        self.agent = agent
        self.stt = stt
        self.tts = tts
        self.max_history = max_history
        self.full_response = ""  # text of the most recent (possibly partial) reply
        self.messages: list[dict[str, str]] = []  # chat memory sent to the agent
        self.stream = None
        self.app = None
        self._setup_webrtc_ip()

    def _setup_webrtc_ip(self):
        """On Windows, export this machine's outbound IPv4 address as ``WEBRTC_IP``.

        Connecting a UDP socket sends no packets. It only makes the OS pick the
        outbound interface, whose address is then read back. If that fails,
        loopback is used instead.
        """
        if platform.system() != "Windows":
            return
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.connect(("8.8.8.8", 80))
                local_ip = s.getsockname()[0]
        except OSError:
            local_ip = "127.0.0.1"
        os.environ["WEBRTC_IP"] = local_ip

    def echo(self, audio):
        """Handle one user utterance: transcribe it, query the agent, speak the reply.

        Args:
            audio: the utterance as delivered by ``ReplyOnPause``. It is passed
                straight to ``audio_to_bytes``. NOTE(review): presumably a
                ``(sample_rate, ndarray)`` tuple; confirm.

        Yields:
            ``(OUTPUT_SAMPLE_RATE, samples)`` frames of synthesized speech.
            After an unexpected error, one second of silence is yielded instead.
        """
        try:
            stt_start = time.time()
            logging.info("Performing STT")
            prompt = self.stt.transcribe(audio_to_bytes(audio))
            if not prompt or not prompt.strip():
                logging.info("STT returned empty string")
                return
            logging.info(f"STT response: {prompt}")
            logging.info(f"STT took {time.time() - stt_start} seconds")

            llm_start = time.time()
            self.full_response = ""
            request = AgentRequest(
                # Pass a snapshot so later history edits can't affect this request.
                chat_memory=list(self.messages),
                prompt_template=get_chat_template("customer_service"),
                question=prompt,
            )
            yield from self._run_async_gen(self._stream_text_to_audio(request))

            self._remember_turn(prompt, self.full_response)
            logging.info(f"LLM response: {self.full_response}")
            logging.info(f"LLM took {time.time() - llm_start} seconds")
        except Exception as e:
            logging.exception(f"Error in echo function: {e}")
            yield (OUTPUT_SAMPLE_RATE, np.zeros(OUTPUT_SAMPLE_RATE, dtype=np.float32))

    async def _stream_text_to_audio(self, request):
        """Stream the agent reply and yield speech frames segment by segment.

        Text is buffered until a chunk contains punctuation, then that buffer
        is synthesized. Any text left when the stream ends is also spoken.
        ``self.full_response`` accumulates the full reply text as a side effect.
        """
        text_buffer = ""
        async for event in self.agent.get_result(request):
            kind = event["type"]
            if kind == "chunk":
                chunk = event["data"]["chunk"]
                self.full_response += chunk
                text_buffer += chunk
                if _SPEAKABLE_BOUNDARY.search(chunk):
                    async for frame in self._speak(text_buffer):
                        yield frame
                    # Cleared even if TTS failed, so a bad segment isn't resent
                    # as a prefix of every later segment.
                    text_buffer = ""
            elif kind == "metadata":
                setup_time = event["data"]["setup_time"]
                print(f"\nSetup completed in {setup_time:.2f}s")
            elif kind == "complete":
                total_time = event["data"]["total_time"]
                print(f"\nTotal time: {total_time:.2f}s")
                break
        # The reply may not end in punctuation; speak whatever is left over.
        if text_buffer.strip():
            async for frame in self._speak(text_buffer):
                yield frame

    async def _speak(self, text):
        """Synthesize *text* and yield it as frames of ``AUDIO_CHUNK_SIZE`` samples.

        TTS errors are logged and the segment is skipped. This way one failure
        does not abort the rest of the reply.
        """
        try:
            audio_buffer = (await self.tts.generate_audio_buffer(text))[0]
            resampled = resample_audio(audio_buffer)
        except Exception as e:
            logging.error(f"TTS generation failed for chunk: {e}")
            return
        for start in range(0, len(resampled), AUDIO_CHUNK_SIZE):
            yield (OUTPUT_SAMPLE_RATE, resampled[start:start + AUDIO_CHUNK_SIZE])

    @staticmethod
    def _run_async_gen(async_gen):
        """Drive *async_gen* from synchronous code on a private event loop.

        The async generator is always finalized with ``aclose`` and the loop is
        always closed. This holds even when the consumer stops iterating early,
        e.g. when the user interrupts the reply.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            while True:
                try:
                    item = loop.run_until_complete(async_gen.__anext__())
                except StopAsyncIteration:
                    break
                yield item
        finally:
            try:
                loop.run_until_complete(async_gen.aclose())
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()
                # Don't leave a closed loop installed as this thread's current loop.
                asyncio.set_event_loop(None)

    def _remember_turn(self, question, answer):
        """Record a user/assistant exchange, keeping only the newest ``max_history`` entries."""
        self.messages.append({"role": "user", "content": question})
        self.messages.append({"role": "assistant", "content": answer})
        overflow = len(self.messages) - self.max_history
        if overflow > 0:
            del self.messages[:overflow]

    def reset_conversation(self):
        """Forget the chat memory and the last response."""
        logging.info("Resetting chat")
        self.messages = []
        self.full_response = ""

    def create_stream(self):
        """Build and cache the fastrtc audio ``Stream``.

        The stream uses Cloudflare TURN credentials and a VAD-driven
        ``ReplyOnPause`` handler around ``echo``.

        Returns:
            The created ``Stream`` (also stored on ``self.stream``).

        Raises:
            Any exception raised while building the stream, after logging it.
        """
        async def get_credentials():
            return await get_cloudflare_turn_credentials_async(hf_token=HF_TOKEN)

        try:
            self.stream = Stream(
                rtc_configuration=get_credentials,
                # Use the same token as the client-side credentials above.
                server_rtc_configuration=get_cloudflare_turn_credentials(
                    hf_token=HF_TOKEN, ttl=360_000
                ),
                handler=ReplyOnPause(
                    self.echo,
                    algo_options=AlgoOptions(
                        audio_chunk_duration=0.5,
                        started_talking_threshold=0.1,
                        speech_threshold=0.03,
                    ),
                    model_options=SileroVadOptions(
                        threshold=0.90,
                        min_speech_duration_ms=250,
                        min_silence_duration_ms=2000,
                        speech_pad_ms=400,
                        max_speech_duration_s=15,
                    ),
                ),
                modality="audio",
                mode="send-receive",
                ui_args={"title": "Sakura A.I Customer Service"},
            )
        except Exception as e:
            logging.exception(f"Error creating stream: {e}")
            raise
        return self.stream

    def create_fastapi_app(self):
        """Build and cache a FastAPI app with the stream mounted.

        Also registers ``/reset`` and ``/status`` endpoints.

        ``self.app`` is only assigned once the app has been fully built.

        Returns:
            The created ``fastapi.FastAPI`` instance.

        Raises:
            Any exception raised while building the app, after logging it.
        """
        try:
            app = fastapi.FastAPI()
            app.add_middleware(
                CORSMiddleware,
                allow_origins=["*"],
                allow_credentials=True,
                allow_methods=["*"],
                allow_headers=["*"],
            )
            if not self.stream:
                self.create_stream()
            self.stream.mount(app)

            @app.get("/reset")
            async def reset():
                try:
                    self.reset_conversation()
                    return {"status": "success"}
                except Exception as e:
                    logging.error(f"Error in reset endpoint: {e}")
                    return {"status": "error", "message": str(e)}

            @app.get("/status")
            async def status():
                try:
                    return {
                        "status": "running",
                        "messages_count": len(self.messages),
                        "last_response": self.full_response,
                    }
                except Exception as e:
                    logging.error(f"Error in status endpoint: {e}")
                    return {"status": "error", "message": str(e)}
        except Exception as e:
            logging.exception(f"Error creating FastAPI app: {e}")
            raise
        self.app = app
        return self.app

    def start_server(self, host: str = "0.0.0.0", port: int = 7862):
        """Serve the FastAPI app (building it if needed) with uvicorn; blocks."""
        import uvicorn

        if not self.app:
            self.create_fastapi_app()
        logging.info(f"Starting server on {host}:{port}")
        try:
            uvicorn.run(self.app, host=host, port=port, log_level="info")
        except Exception as e:
            logging.exception(f"Error starting server: {e}")
            raise

    def launch_ui(self, browser: bool = True, port=7860):
        """Launch the stream's Gradio UI with a public share link.

        Args:
            browser: whether Gradio should open the UI in a local browser.
            port: port for the Gradio server.
        """
        try:
            if not self.stream:
                self.create_stream()
            if not self.app:
                # NOTE(review): Blocks.launch serves its own app; self.app is
                # built here only so it is available to other callers.
                self.create_fastapi_app()
            logging.info("Launching RTC UI...")
            # Previously self.app was passed positionally, landing in Gradio's
            # `inline` parameter, and `browser` was ignored.
            self.stream.ui.launch(
                inbrowser=browser,
                server_name="0.0.0.0",
                server_port=port,
                share=True,
            )
        except Exception as e:
            logging.exception(f"Error launching UI: {e}")
            raise

    def get_conversation_history(self):
        """Return a shallow copy of the chat memory."""
        return self.messages.copy()

    def get_last_response(self):
        """Return the text of the most recent agent reply."""
        return self.full_response