Spaces:
Configuration error
Configuration error
| import fastapi | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastrtc import ReplyOnPause, Stream, AlgoOptions, SileroVadOptions, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials | |
| from fastrtc.utils import audio_to_int16 | |
| from openai import OpenAI | |
| from elevenlabs.client import ElevenLabs | |
| from dotenv import load_dotenv | |
| from tts.audio_edge_tts import EdgeTTS | |
| from rag import document_retriever, ddgs | |
| import logging | |
| import time | |
| import platform | |
| import socket | |
| import os | |
| import numpy as np | |
| import io | |
| import wave | |
| import asyncio | |
| import librosa | |
| from pydub import AudioSegment | |
| from collections import deque | |
| import torch | |
| import torchaudio.transforms as T | |
| import concurrent.futures | |
| import threading | |
| from config.constant import HF_TOKEN | |
| import re | |
| from langchain_core.documents import Document | |
| import torchaudio | |
# Populate os.environ from a local .env file (variables already set in the
# environment win), then send INFO-and-above records through the root logger.
load_dotenv(override=False)
logging.basicConfig(level="INFO")
class RTCHandler:
    """Indonesian-language voice customer-service agent served over WebRTC.

    Per user turn (segmented by fastrtc's ``ReplyOnPause``):

    1. OpenAI ``whisper-1`` transcribes the caller's audio.
    2. ``document_retriever`` supplies context.
    3. ``gpt-3.5-turbo`` streams a reply.
    4. Each phrase is synthesized by ``edge_tts``.
    5. The speech is sent back as 24 kHz mono audio frames.

    NOTE(review): ``create_stream`` hands the bound ``self.echo`` to
    ``ReplyOnPause``, so every connection appears to share one ``messages``
    history, ``prompt`` and ``full_response``. Confirm whether per-session
    state is required. The history is also never trimmed, and every stored
    user turn embeds its retrieved context, so long calls may eventually
    exceed the model's context window.
    """

    # Sample rate (Hz) of every audio frame yielded back to the client.
    OUTPUT_SAMPLE_RATE = 24000
    # Number of samples per yielded audio frame.
    AUDIO_CHUNK_SIZE = 1024
    # A streamed delta containing any of these characters ends a phrase. The
    # phrase is synthesized immediately so playback starts before the whole
    # reply has been generated.
    _PHRASE_BREAK = re.compile(r"[.,?;!]")

    def __init__(self, openai_client: "OpenAI", whisper_stt=None, edge_tts: "EdgeTTS | None" = None):
        """Create the handler and, on Windows, export ``WEBRTC_IP``.

        Args:
            openai_client: Client used for both transcription and chat completions.
            whisper_stt: Stored as ``self.whisper_stt``; not used by this class.
            edge_tts: Speech engine whose ``generate_audio_buffer`` coroutine
                returns MP3 audio. ``echo`` requires it.
        """
        self.whisper_stt = whisper_stt
        self.edge_tts = edge_tts
        # Latest transcription of the caller's speech.
        self.prompt = ""
        # System prompt (Indonesian): a polite but relaxed customer-service
        # persona, answering over the phone, clear and at most 50 words.
        self.sys_prompt = (
            "Kamu adalah customer service yang berbahasa Indonesia dengan baik sopan, santun, tapi santai pembawaannya.\n"
            "Kamu bisa menjelaskan sesuatu secara baik dan membimbing customer dalam menghadapi masalah yang ada!\n"
            "Kamu akan menjawab customer dengan media call /telepon jadi anda harus memberikan respon seperlunya saja\n"
            "Tidak kepanjanngan, dan sangat jelas, Tidak lebih dari 50 kata."
        )
        self.openai_client = openai_client
        self.messages = [{"role": "system", "content": self.sys_prompt}]
        # Text of the reply currently being (or most recently) generated.
        self.full_response = ""
        self.stream = None
        self.app = None
        self._setup_webrtc_ip()

    def _setup_webrtc_ip(self):
        """On Windows, export this host's outbound IPv4 address as ``WEBRTC_IP``.

        Connecting a UDP socket sends no packets. It only makes the OS choose
        the local interface that routes to 8.8.8.8. If that fails (e.g. the
        host is offline), loopback is used instead.
        NOTE(review): presumably read by the WebRTC stack; confirm the consumer.
        """
        if platform.system() != "Windows":
            return
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            try:
                probe.connect(("8.8.8.8", 80))
                local_ip = probe.getsockname()[0]
            except OSError:
                local_ip = "127.0.0.1"
        os.environ["WEBRTC_IP"] = local_ip

    def audio_to_bytes(self, audio_tuple, sample_rate=24000) -> io.BytesIO:
        """Encode a ``(sample_rate, samples)`` tuple as an in-memory mono 16-bit WAV.

        Args:
            audio_tuple: ``(sample_rate, samples)`` as delivered by fastrtc.
            sample_rate: Unused. The WAV header takes its rate from
                ``audio_tuple``. Kept for backward compatibility.

        Returns:
            A ``BytesIO`` rewound to position 0 and named ``audio.wav``. The
            upload code presumably uses that name to pick the file name/format.
        """
        rate = audio_tuple[0]
        pcm16 = audio_to_int16(audio_tuple)
        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wav:
            wav.setnchannels(1)
            wav.setsampwidth(2)  # 16-bit PCM, matching audio_to_int16's output
            wav.setframerate(rate)
            wav.writeframes(pcm16.tobytes())
        buffer.seek(0)
        buffer.name = "audio.wav"
        return buffer

    def echo(self, audio):
        """Handle one caller utterance; this is the ``ReplyOnPause`` callback.

        Transcribes ``audio`` (Indonesian). It then streams a retrieval-grounded
        chat reply and yields the spoken reply phrase by phrase. Afterwards the
        assistant message is appended to ``self.messages``.

        Args:
            audio: ``(sample_rate, samples)`` tuple of the caller's speech.

        Yields:
            ``(OUTPUT_SAMPLE_RATE, float32 ndarray)`` frames of at most
            ``AUDIO_CHUNK_SIZE`` samples. Nothing is yielded if the
            transcription is empty. On any error, one second of silence is
            yielded after the error is logged.
        """
        try:
            stt_start = time.perf_counter()
            logging.info("Performing STT")
            transcription = self.openai_client.audio.transcriptions.create(
                model="whisper-1",
                file=self.audio_to_bytes(audio),
                language="id",
            )
            self.prompt = transcription.text
            if not self.prompt:
                logging.info("STT returned empty string")
                return
            logging.info(f"STT response: {transcription}")
            logging.info(f"STT took {time.perf_counter() - stt_start} seconds")

            llm_start = time.perf_counter()
            self.full_response = ""
            yield from self._iterate_async_gen(self._stream_reply_audio())

            self.messages.append({"role": "assistant", "content": self.full_response + " "})
            logging.info(f"LLM response: {self.full_response}")
            logging.info(f"LLM took {time.perf_counter() - llm_start} seconds")
        except Exception as e:
            logging.exception(f"Error in echo function: {e}")
            yield (self.OUTPUT_SAMPLE_RATE, np.zeros(self.OUTPUT_SAMPLE_RATE, dtype=np.float32))

    @staticmethod
    def _iterate_async_gen(agen):
        """Drive async generator *agen* on a private event loop and yield its items synchronously.

        The loop is torn down even when the consumer stops early, e.g. when
        ``ReplyOnPause`` interrupts the reply. *agen* is finalized on this
        loop first, so its ``finally`` blocks run. Then the thread's current
        event loop is reset so a closed loop is not left registered.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            while True:
                try:
                    item = loop.run_until_complete(agen.__anext__())
                except StopAsyncIteration:
                    return
                yield item
        finally:
            try:
                loop.run_until_complete(agen.aclose())
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                asyncio.set_event_loop(None)
                loop.close()

    async def _stream_reply_audio(self):
        """Answer ``self.prompt`` with retrieved context and yield the reply as audio frames."""
        retrieval_result = await document_retriever.retrieve(query=self.prompt)

        # Web-search results are added to the store for later retrievals. They
        # are fetched after the retrieval above, so they do not inform this answer.
        # NOTE(review): confirm that ordering is intended.
        search_results = [
            Document(
                page_content=result,
                metadata={"source": "internet_search", "query": self.prompt},
            )
            async for result in ddgs.search(self.prompt, max_results=5)
        ]
        await document_retriever.add_documents(search_results)

        contexts = "".join(
            f"{i}. {doc.page_content}\n" for i, doc in enumerate(retrieval_result.documents, 1)
        )
        self.messages.append({
            "role": "user",
            "content": (
                f"Dari Konteks yang diberikan (jika diperlukan) :\n{contexts}\n"
                f"Berikan jawaban atas pertanyaan yang diberikan :\n{self.prompt}"
            ),
        })

        response = self.openai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=self.messages,
            max_tokens=200,
            stream=True,
        )
        pending = ""
        try:
            for chunk in response:
                if not chunk.choices:
                    continue
                choice = chunk.choices[0]
                # Consume the delta before looking at finish_reason so a final
                # chunk that carries both is not lost.
                delta = choice.delta.content
                if delta:
                    self.full_response += delta
                    pending += delta
                    if self._PHRASE_BREAK.search(delta):
                        frames = await self._phrase_frames(pending)
                        # On TTS failure the text is kept and retried with the next phrase.
                        if frames is not None:
                            pending = ""
                            for frame in frames:
                                yield frame
                if choice.finish_reason is not None:
                    break
        finally:
            # Release the HTTP connection even if the consumer stopped early.
            response.close()

        # Speak the text after the last phrase break, whatever the finish
        # reason (e.g. "length" at max_tokens). The consumer expects audio
        # frames, never text.
        if pending:
            for frame in await self._phrase_frames(pending) or ():
                yield frame

    async def _phrase_frames(self, text: str):
        """Synthesize *text* and split it into frames; return ``None`` (after logging) if TTS fails."""
        try:
            samples = await self._synthesize(text)
        except Exception as e:
            logging.error(f"TTS generation failed for chunk: {e}")
            return None
        size = self.AUDIO_CHUNK_SIZE
        return [
            (self.OUTPUT_SAMPLE_RATE, samples[start:start + size])
            for start in range(0, len(samples), size)
        ]

    async def _synthesize(self, text: str):
        """Return *text* spoken by ``self.edge_tts`` as mono float32 samples at ``OUTPUT_SAMPLE_RATE``."""
        # The first element returned by generate_audio_buffer is decoded as MP3 here.
        mp3_buffer = (await self.edge_tts.generate_audio_buffer(text))[0]
        mp3_buffer.seek(0)
        segment = AudioSegment.from_file(mp3_buffer, format="mp3")
        # Integer PCM -> [-1, 1): divide by 2**(bits - 1), i.e. 32768 for 16-bit audio.
        full_scale = float(1 << (8 * segment.sample_width - 1))
        samples = np.array(segment.get_array_of_samples()).astype(np.float32) / full_scale
        if segment.channels == 2:
            # Stereo samples are interleaved L/R; average each pair down to mono.
            samples = samples.reshape((-1, 2)).mean(axis=1)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        resampler = torchaudio.transforms.Resample(
            orig_freq=segment.frame_rate,
            new_freq=self.OUTPUT_SAMPLE_RATE,
        ).to(device)
        waveform = torch.from_numpy(samples).unsqueeze(0).to(device)
        return resampler(waveform).squeeze(0).cpu().numpy()

    def reset_conversation(self):
        """Drop the chat history (keeping the system prompt) and the last response."""
        logging.info("Resetting chat")
        self.messages = [{"role": "system", "content": self.sys_prompt}]
        self.full_response = ""

    def create_stream(self):
        """Build, store and return the fastrtc audio ``Stream`` that routes paused speech to ``echo``.

        Raises:
            Exception: Any construction error is logged and re-raised.
        """
        try:
            async def get_credentials():
                # Client-side TURN configuration, supplied to fastrtc as a coroutine.
                return await get_cloudflare_turn_credentials_async(hf_token=HF_TOKEN)

            self.stream = Stream(
                rtc_configuration=get_credentials,
                # Fetched once, here, using the same token as the client
                # credentials. NOTE(review): assumes the server runs for less
                # than the ttl (360_000 s).
                server_rtc_configuration=get_cloudflare_turn_credentials(hf_token=HF_TOKEN, ttl=360_000),
                handler=ReplyOnPause(
                    self.echo,
                    algo_options=AlgoOptions(
                        audio_chunk_duration=0.5,
                        started_talking_threshold=0.1,
                        speech_threshold=0.03,
                    ),
                    model_options=SileroVadOptions(
                        threshold=0.90,
                        min_speech_duration_ms=250,
                        min_silence_duration_ms=2000,
                        speech_pad_ms=400,
                        max_speech_duration_s=15,
                    ),
                ),
                modality="audio",
                mode="send-receive",
            )
            return self.stream
        except Exception as e:
            logging.error(f"Error creating stream: {e}")
            raise

    def create_fastapi_app(self):
        """Build, store and return a FastAPI app serving the stream plus helper endpoints.

        Registers ``POST /reset`` (clears the conversation) and ``GET /status``
        (message count and last response). The stream is created first if needed.

        Raises:
            Exception: Any construction error is logged and re-raised.
        """
        try:
            self.app = fastapi.FastAPI()
            # NOTE(review): wildcard origins together with allow_credentials=True
            # effectively lets any site make credentialed requests; tighten if
            # the API ever depends on cookies.
            self.app.add_middleware(
                CORSMiddleware,
                allow_origins=["*"],
                allow_credentials=True,
                allow_methods=["*"],
                allow_headers=["*"],
            )
            if not self.stream:
                self.create_stream()
            self.stream.mount(self.app)

            async def reset():
                try:
                    self.reset_conversation()
                    return {"status": "success"}
                except Exception as e:
                    logging.error(f"Error in reset endpoint: {e}")
                    return {"status": "error", "message": str(e)}

            async def status():
                try:
                    return {
                        "status": "running",
                        "messages_count": len(self.messages),
                        "last_response": self.full_response,
                    }
                except Exception as e:
                    logging.error(f"Error in status endpoint: {e}")
                    return {"status": "error", "message": str(e)}

            # Without registration these handlers are unreachable.
            self.app.add_api_route("/reset", reset, methods=["POST"])
            self.app.add_api_route("/status", status, methods=["GET"])
            return self.app
        except Exception as e:
            logging.error(f"Error creating FastAPI app: {e}")
            raise

    def start_server(self, host: str = "0.0.0.0", port: int = 7860):
        """Serve the FastAPI app (built on demand) with uvicorn; blocks until the server exits."""
        import uvicorn

        if not self.app:
            self.create_fastapi_app()
        logging.info(f"Starting server on {host}:{port}")
        try:
            uvicorn.run(self.app, host=host, port=port, log_level="info")
        except Exception as e:
            logging.error(f"Error starting server: {e}")
            raise

    def launch_ui(self, browser: bool = True):
        """Launch the stream's Gradio UI on 0.0.0.0:7860.

        Args:
            browser: Forwarded to Gradio as ``inbrowser`` (open a browser tab on launch).
        """
        try:
            if not self.stream:
                self.create_stream()
            if not self.app:
                # Kept so ``self.app`` is populated as before; the Gradio server
                # launched below does not serve this app.
                self.create_fastapi_app()
            logging.info("Launching RTC UI...")
            # launch()'s first positional parameter is ``inline``, so nothing is
            # passed positionally here.
            self.stream.ui.launch(
                server_name="0.0.0.0",
                server_port=7860,
                inbrowser=browser,
            )
        except Exception as e:
            logging.error(f"Error launching UI: {e}")
            raise

    def get_conversation_history(self):
        """Return a shallow copy of the chat message list."""
        return self.messages.copy()

    def set_system_prompt(self, new_prompt: str):
        """Replace the system prompt, both for future resets and in the current history."""
        self.sys_prompt = new_prompt
        self.messages[0] = {"role": "system", "content": new_prompt}

    def get_last_response(self):
        """Return the text of the current or most recent assistant reply."""
        return self.full_response