Spaces:
Configuration error
Configuration error
| import fastapi | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastrtc import ReplyOnPause, Stream, AlgoOptions, SileroVadOptions, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials | |
| from src.utils.audio_helper import audio_to_bytes, resample_audio | |
| from dotenv import load_dotenv | |
| from src.internal.rag.chat_template import get_chat_template | |
| from src.internal.tts.base_tts import TTS | |
| from src.internal.stt.base_stt import STT | |
| from src.internal.agents import Agent, AgentRequest, BankingCRUDAgent, CSAgent | |
| import logging | |
| import time | |
| import platform | |
| import socket | |
| import os | |
| import numpy as np | |
| import asyncio | |
| import asyncio | |
| from src.config.constant import HF_TOKEN | |
| import re | |
# Populate os.environ from a .env file without overwriting variables that are
# already set in the process environment.
# NOTE(review): this runs after the `src.*` imports above; if any of those
# modules read environment variables at import time (e.g. for HF_TOKEN), they
# will not see values that only exist in .env — confirm, or move this call
# above those imports.
load_dotenv(override=False)

# Route INFO-and-above records through the root logger.
logging.basicConfig(level="INFO")
class RTCHandler:
    """Wire an STT -> agent -> TTS voice pipeline into a fastrtc ``Stream``.

    Each utterance detected by ``ReplyOnPause`` is transcribed with ``stt``.
    The transcript is sent to ``agent`` together with the conversation history
    kept in ``self.messages``. The agent's streamed text is synthesized with
    ``tts`` segment by segment and streamed back as audio frames.

    NOTE(review): all conversation state (``messages``, ``full_response``) is
    stored on this instance, and ``create_stream`` hands the bound ``self.echo``
    to ``ReplyOnPause``. If the stream serves several connections at once, they
    share one conversation. That matches the single-conversation ``/reset`` and
    ``/status`` endpoints, but it is not safe for multiple concurrent users.
    """

    # Upper bound on history entries; the oldest user/assistant pair is dropped first.
    MAX_HISTORY_MESSAGES = 15
    # Sample rate reported with every yielded audio frame.
    # NOTE(review): assumes resample_audio() returns audio at this rate — confirm.
    OUTPUT_SAMPLE_RATE = 24000
    # Number of samples per yielded audio frame.
    AUDIO_CHUNK_SIZE = 1024
    # A text chunk containing any of these characters triggers synthesis of the buffer.
    _FLUSH_PUNCTUATION = re.compile(r"[.,?;!]")

    def __init__(self, agent: Agent, stt: STT, tts: TTS):
        self.agent = agent
        self.stt = stt
        self.tts = tts
        # Conversation history as {"role": ..., "content": ...} dicts, sent to the agent each turn.
        self.messages = []
        # Text of the most recent (or in-progress) agent reply.
        self.full_response = ""
        self.stream = None
        self.app = None
        self._setup_webrtc_ip()

    def _setup_webrtc_ip(self):
        """On Windows, export this machine's outbound LAN address as ``WEBRTC_IP``.

        Connecting a UDP socket sends no datagram. It only makes the OS choose the
        local address it would use to reach 8.8.8.8. If that fails, the address
        falls back to 127.0.0.1. On other platforms this method does nothing.
        """
        if platform.system() != 'Windows':
            return
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            try:
                s.connect(('8.8.8.8', 80))
                local_ip = s.getsockname()[0]
            except OSError:
                local_ip = '127.0.0.1'
        os.environ['WEBRTC_IP'] = local_ip

    def echo(self, audio):
        """Answer one user utterance with streamed speech.

        Transcribes ``audio`` and sends the transcript plus prior history to the
        agent. It then yields audio frames as the reply is synthesized, and
        finally records the turn in ``self.messages``. If the transcript is
        empty, it yields nothing.

        Args:
            audio: the utterance delivered by ``ReplyOnPause``; converted with
                ``audio_to_bytes`` before transcription.

        Yields:
            ``(OUTPUT_SAMPLE_RATE, samples)`` tuples. On an unexpected error,
            one second of silence is yielded instead.
        """
        try:
            stt_time = time.time()
            logging.info("Performing STT")
            transcription = self.stt.transcribe(audio_to_bytes(audio))
            # Whitespace-only (or missing) transcripts are treated as empty.
            if not transcription or not transcription.strip():
                logging.info("STT returned empty string")
                return
            prompt = transcription
            logging.info(f"STT response: {transcription}")
            logging.info(f"STT took {time.time() - stt_time} seconds")

            llm_time = time.time()
            self.full_response = ""
            # Snapshot the history so this request is unaffected by the append below.
            request = self._build_agent_request(prompt, list(self.messages))
            yield from self._run_async_generator(self._stream_text_to_audio(request))

            self._remember_turn(prompt, self.full_response)
            logging.info(f"LLM response: {self.full_response}")
            logging.info(f"LLM took {time.time() - llm_time} seconds")
        except Exception as e:
            logging.exception(f"Error in echo function: {e}")
            error_audio = np.zeros(self.OUTPUT_SAMPLE_RATE, dtype=np.float32)
            yield (self.OUTPUT_SAMPLE_RATE, error_audio)

    def _build_agent_request(self, prompt, chat_memory):
        """Build the ``AgentRequest`` for ``prompt`` using the templates ``self.agent`` expects."""
        if isinstance(self.agent, BankingCRUDAgent):
            prompt_template = {
                "api_banking_template": get_chat_template("api_banking"),
                "responder_template": get_chat_template("responder_banking"),
            }
        else:
            prompt_template = get_chat_template("customer_service")
        return AgentRequest(
            chat_memory=chat_memory,
            prompt_template=prompt_template,
            question=prompt,
        )

    async def _stream_text_to_audio(self, request):
        """Stream the agent's reply to ``request`` as audio frames.

        Each text chunk is appended to ``self.full_response`` and to a pending
        buffer. The buffer is synthesized whenever a chunk contains punctuation.
        Any text still pending when the agent finishes is synthesized too, so
        replies that do not end in punctuation are still spoken.
        """
        text_buffer = ""
        async for stream_data in self.agent.get_result(request):
            kind = stream_data["type"]
            if kind == "chunk":
                chunk = stream_data["data"]["chunk"]
                self.full_response += chunk
                text_buffer += chunk
                if self._FLUSH_PUNCTUATION.search(chunk):
                    async for frame in self._synthesize(text_buffer):
                        yield frame
                    # Cleared even if TTS failed, so one bad segment is not re-sent
                    # (and re-failed) with every later segment.
                    text_buffer = ""
            elif kind == "metadata":
                setup_time = stream_data['data']['setup_time']
                logging.info(f"Setup completed in {setup_time:.2f}s")
            elif kind == "complete":
                total_time = stream_data['data']['total_time']
                logging.info(f"Total time: {total_time:.2f}s")
                break
        if text_buffer.strip():
            async for frame in self._synthesize(text_buffer):
                yield frame

    async def _synthesize(self, text):
        """Yield audio frames for ``text``; if TTS or resampling fails, log the error and yield nothing."""
        try:
            audio_buffer_gen = await self.tts.generate_audio_buffer(text)
            resampled = resample_audio(audio_buffer_gen[0])
        except Exception as e:
            logging.error(f"TTS generation failed for chunk: {e}")
            return
        step = self.AUDIO_CHUNK_SIZE
        for start in range(0, len(resampled), step):
            yield (self.OUTPUT_SAMPLE_RATE, resampled[start:start + step])

    @staticmethod
    def _run_async_generator(async_gen):
        """Drive ``async_gen`` on a private event loop and yield its items synchronously.

        ``echo`` is a plain generator, so the async agent/TTS pipeline is stepped
        on a fresh loop owned by this call. The async generator (and any async
        generators it left open) is finalized before the loop is closed. This
        also happens when the consumer stops iterating early, e.g. when the user
        interrupts.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            while True:
                try:
                    item = loop.run_until_complete(async_gen.__anext__())
                except StopAsyncIteration:
                    break
                yield item
        finally:
            try:
                loop.run_until_complete(async_gen.aclose())
                loop.run_until_complete(loop.shutdown_asyncgens())
            except Exception as e:
                # Best-effort cleanup: never mask the original exit path.
                logging.warning(f"Async generator cleanup failed: {e}")
            finally:
                asyncio.set_event_loop(None)
                loop.close()

    def _remember_turn(self, prompt, response):
        """Append a user/assistant turn to ``self.messages``.

        Whole turns are dropped from the front until at most
        ``MAX_HISTORY_MESSAGES`` entries remain.
        """
        self.messages.append({"role": "user", "content": prompt})
        self.messages.append({"role": "assistant", "content": response})
        while len(self.messages) > self.MAX_HISTORY_MESSAGES:
            del self.messages[:2]

    def reset_conversation(self):
        """Forget the conversation history and the last response."""
        logging.info("Resetting chat")
        self.messages = []
        self.full_response = ""

    def create_stream(self):
        """Create ``self.stream``: an audio send/receive fastrtc Stream driven by ``echo``.

        Returns:
            The created ``Stream`` (also stored on ``self.stream``).

        Raises:
            Any exception raised while building the stream, after logging it.
        """
        try:
            # Passed uncalled: client TURN credentials are fetched through this coroutine function.
            async def get_credentials():
                return await get_cloudflare_turn_credentials_async(hf_token=HF_TOKEN)

            self.stream = Stream(
                rtc_configuration=get_credentials,
                server_rtc_configuration=get_cloudflare_turn_credentials(ttl=360_000),
                handler=ReplyOnPause(
                    self.echo,
                    algo_options=AlgoOptions(
                        audio_chunk_duration=0.5,
                        started_talking_threshold=0.1,
                        speech_threshold=0.03,
                    ),
                    # Silero VAD settings; see fastrtc's SileroVadOptions for field meanings.
                    model_options=SileroVadOptions(
                        threshold=0.90,
                        min_speech_duration_ms=250,
                        min_silence_duration_ms=2000,
                        speech_pad_ms=400,
                        max_speech_duration_s=15,
                    ),
                ),
                modality="audio",
                mode="send-receive",
                ui_args={"title": "Sakura A.I Customer Service"},
            )
            return self.stream
        except Exception as e:
            logging.error(f"Error creating stream: {e}")
            raise

    def create_fastapi_app(self):
        """Create ``self.app``: a FastAPI app with the stream mounted plus control routes.

        Routes added here:
            POST /reset  -> clear the conversation.
            GET  /status -> history size and last response.

        Returns:
            The FastAPI application (also stored on ``self.app``).

        Raises:
            Any exception raised while building the app, after logging it.
        """
        try:
            self.app = fastapi.FastAPI()
            # NOTE(review): wildcard origins together with allow_credentials=True let
            # any website send credentialed requests; restrict allow_origins in production.
            self.app.add_middleware(
                CORSMiddleware,
                allow_origins=["*"],
                allow_credentials=True,
                allow_methods=["*"],
                allow_headers=["*"],
            )
            if not self.stream:
                self.create_stream()
            self.stream.mount(self.app)

            @self.app.post("/reset")
            async def reset():
                try:
                    self.reset_conversation()
                    return {"status": "success"}
                except Exception as e:
                    logging.error(f"Error in reset endpoint: {e}")
                    return {"status": "error", "message": str(e)}

            @self.app.get("/status")
            async def status():
                try:
                    return {
                        "status": "running",
                        "messages_count": len(self.messages),
                        "last_response": self.full_response,
                    }
                except Exception as e:
                    logging.error(f"Error in status endpoint: {e}")
                    return {"status": "error", "message": str(e)}

            return self.app
        except Exception as e:
            logging.error(f"Error creating FastAPI app: {e}")
            raise

    def start_server(self, host: str = "0.0.0.0", port: int = 7862):
        """Serve ``self.app`` (created on demand) with uvicorn on ``host:port``.

        Raises:
            Any exception raised by uvicorn, after logging it.
        """
        import uvicorn
        if not self.app:
            self.create_fastapi_app()
        logging.info(f"Starting server on {host}:{port}")
        try:
            uvicorn.run(self.app, host=host, port=port, log_level="info")
        except Exception as e:
            logging.error(f"Error starting server: {e}")
            raise

    def launch_ui(self, browser: bool = True, port: int = 7860):
        """Launch the stream's Gradio UI on 0.0.0.0:``port`` with a public share link.

        Args:
            browser: open the UI in a local browser tab once the server is up.
            port: port for the Gradio server.

        NOTE(review): the Gradio server started here does not serve ``self.app``,
        so the /reset and /status routes are only reachable through ``start_server``.
        ``self.app`` is still built here so callers can access it afterwards.

        Raises:
            Any exception raised while launching, after logging it.
        """
        try:
            if not self.stream:
                self.create_stream()
            if not self.app:
                self.create_fastapi_app()
            logging.info("Launching RTC UI...")
            # Keyword arguments only: the first positional parameter of launch() is `inline`.
            self.stream.ui.launch(
                inbrowser=browser,
                server_name="0.0.0.0",
                server_port=port,
                share=True,
            )
        except Exception as e:
            logging.error(f"Error launching UI: {e}")
            raise

    def get_conversation_history(self):
        """Return a shallow copy of the conversation history."""
        return self.messages.copy()

    def get_last_response(self):
        """Return the text of the most recent agent reply."""
        return self.full_response