import asyncio
import io
import os
import pathlib
import time
from typing import List, Literal

import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import AsyncStreamHandler, Stream, wait_for_item
from pydantic import BaseModel
import uvicorn

# Document parsing, embeddings, vector search, and generation (RAG pipeline)
import PyPDF2
import docx
import faiss
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# Speech-to-text and text-to-speech
import whisper
from gtts import gTTS
from pydub import AudioSegment

load_dotenv()
current_dir = pathlib.Path(__file__).parent


# --- RAG knowledge base: load documents, chunk them, and index the chunks ---

DOCS_FOLDER = current_dir / "docs"
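# The filenames below are only an illustration; any .pdf, .docx, or .txt files
# placed in ./docs are picked up at startup, e.g.:
#   docs/faq.pdf
#   docs/return_policy.docx
#   docs/shipping_info.txt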


def extract_text_from_pdf(file_path: pathlib.Path) -> str:
    text = ""
    with open(file_path, "rb") as f:
        reader = PyPDF2.PdfReader(f)
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                text += page_text + "\n"
    return text


def extract_text_from_docx(file_path: pathlib.Path) -> str:
    doc = docx.Document(file_path)
    return "\n".join(para.text for para in doc.paragraphs)


def extract_text_from_txt(file_path: pathlib.Path) -> str:
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()


def load_documents(folder: pathlib.Path) -> List[str]:
    documents = []
    for file_path in folder.glob("*"):
        if file_path.suffix.lower() == ".pdf":
            documents.append(extract_text_from_pdf(file_path))
        elif file_path.suffix.lower() == ".docx":
            # python-docx can only parse .docx; legacy binary .doc files are not supported.
            documents.append(extract_text_from_docx(file_path))
        elif file_path.suffix.lower() == ".txt":
            documents.append(extract_text_from_txt(file_path))
    return documents


def split_text(text: str, max_length: int = 500, overlap: int = 100) -> List[str]:
    # Fixed-size character chunks with overlap, so sentences that straddle a
    # boundary still appear intact in at least one chunk.
    chunks = []
    start = 0
    while start < len(text):
        end = start + max_length
        chunks.append(text[start:end])
        start += max_length - overlap
    return chunks
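# With the defaults, chunk starts advance by max_length - overlap = 400 characters,
# so a 1,000-character document yields chunks covering [0, 500), [400, 900), [800, 1000).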


# Build the chunk corpus and FAISS index once at startup.
documents = load_documents(DOCS_FOLDER)
all_chunks = []
for doc in documents:
    all_chunks.extend(split_text(doc))

if not all_chunks:
    raise RuntimeError(
        f"No readable documents found in {DOCS_FOLDER}; "
        "add .pdf, .docx, or .txt files before starting."
    )

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
chunk_embeddings = embedding_model.encode(all_chunks)
embedding_dim = chunk_embeddings.shape[1]
faiss_index = faiss.IndexFlatL2(embedding_dim)
faiss_index.add(np.array(chunk_embeddings, dtype=np.float32))


# Small local model for demo purposes; generation length is controlled per call
# via max_new_tokens in generate_answer().
generator = pipeline("text-generation", model="gpt2")
|
| | def retrieve_context(query: str, k: int = 5) -> List[str]: |
| | query_embedding = embedding_model.encode([query]) |
| | distances, indices = faiss_index.search(np.array(query_embedding), k) |
| | return [all_chunks[idx] for idx in indices[0] if idx < len(all_chunks)] |


def generate_answer(query: str) -> str:
    context_chunks = retrieve_context(query)
    context = "\n".join(context_chunks)
    prompt = (
        "You are a customer support agent. Use the following context to answer the question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer:"
    )
    response = generator(prompt, max_new_tokens=100, do_sample=True, temperature=0.7)
    generated_text = response[0]["generated_text"]

    # The pipeline echoes the prompt, so keep only the text after "Answer:".
    if "Answer:" in generated_text:
        answer = generated_text.split("Answer:", 1)[1].strip()
    else:
        answer = generated_text.strip()
    return answer
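# Example (hypothetical question; the answer depends on the documents in ./docs
# and on sampling, so it will vary from run to run):
#
#   >>> generate_answer("How do I reset my password?")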


# --- Speech-to-text and text-to-speech helpers ---

# Whisper "base" on CPU keeps the demo self-contained; larger models trade speed for accuracy.
stt_model = whisper.load_model("base", device="cpu")


def speech_to_text(audio_array: np.ndarray, sample_rate: int = 16000) -> str:
    # Whisper expects 16 kHz mono float32 in [-1, 1]; the handler feeds audio
    # at that rate, so no resampling is done here.
    audio_float = audio_array.astype(np.float32) / 32768.0
    result = stt_model.transcribe(audio_float, fp16=False)
    return result["text"]
|
| | def text_to_speech(text: str, lang="en", target_sample_rate: int = 24000) -> np.ndarray: |
| | tts = gTTS(text, lang=lang) |
| | mp3_fp = io.BytesIO() |
| | tts.write_to_fp(mp3_fp) |
| | mp3_fp.seek(0) |
| | audio = AudioSegment.from_file(mp3_fp, format="mp3") |
| | audio = audio.set_frame_rate(target_sample_rate).set_channels(1) |
| | return np.array(audio.get_array_of_samples(), dtype=np.int16) |
| |
|
| | |
| | |
| | |
| |
|
class RAGVoiceHandler(AsyncStreamHandler):
    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        self.input_buffer = bytearray()
        # time.monotonic() is used instead of the event-loop clock because the
        # handler is instantiated at import time, before any loop is running.
        self.last_input_time = time.monotonic()

    def copy(self) -> "RAGVoiceHandler":
        return RAGVoiceHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    async def start_up(self) -> None:
        # Accumulate microphone audio; a 0.5 s gap with no new frames is treated
        # as the end of an utterance and triggers the STT -> RAG -> TTS pipeline.
        while not self.quit.is_set():
            try:
                audio_data = await asyncio.wait_for(self.input_queue.get(), timeout=0.5)
                self.input_buffer.extend(audio_data)
                self.last_input_time = time.monotonic()
            except asyncio.TimeoutError:
                if self.input_buffer:
                    audio_array = np.frombuffer(self.input_buffer, dtype=np.int16)
                    self.input_buffer = bytearray()
                    # The models are blocking, so run them in a worker thread to
                    # keep the event loop (and the WebRTC connection) responsive.
                    query_text = await asyncio.to_thread(
                        speech_to_text, audio_array, sample_rate=self.input_sample_rate
                    )
                    if query_text.strip():
                        print("Transcribed query:", query_text)
                        answer_text = await asyncio.to_thread(generate_answer, query_text)
                        print("Generated answer:", answer_text)
                        tts_audio = await asyncio.to_thread(
                            text_to_speech, answer_text, target_sample_rate=self.output_sample_rate
                        )
                        self.output_queue.put_nowait((self.output_sample_rate, tts_audio))
                await asyncio.sleep(0.1)

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        sample_rate, audio_array = frame
        audio_bytes = audio_array.tobytes()
        await self.input_queue.put(audio_bytes)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        self.quit.set()


# --- WebRTC stream and FastAPI app ---

# The public STUN/TURN servers below are for demonstration only; production
# deployments should supply their own TURN credentials.
rtc_config = {
    "iceServers": [
        {"urls": "stun:stun.l.google.com:19302"},
        {
            "urls": "turn:turn.anyfirewall.com:443?transport=tcp",
            "username": "webrtc",
            "credential": "webrtc",
        },
    ]
}

stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=RAGVoiceHandler(),
    rtc_configuration=rtc_config,
    concurrency_limit=5,
    time_limit=90,
)


class InputData(BaseModel):
    webrtc_id: str


app = FastAPI()
# Mounting registers the WebRTC signalling routes (including /webrtc/offer) on the app.
stream.mount(app)


@app.post("/input_hook")
async def input_hook(body: InputData):
    # Associate the client's WebRTC connection id with the stream.
    stream.set_input(body.webrtc_id)
    return {"status": "ok"}


@app.post("/chat")
async def chat_endpoint(payload: dict):
    # Text-only endpoint for exercising the RAG pipeline without audio.
    question = payload.get("question", "")
    if not question:
        return {"error": "No question provided"}
    answer = generate_answer(question)
    return {"answer": answer}


@app.get("/")
async def index_endpoint():
    index_path = current_dir / "index.html"
    html_content = index_path.read_text()
    return HTMLResponse(content=html_content)


# --- Entry point ---

if __name__ == "__main__":
    mode = os.getenv("MODE", "PHONE")
    if mode == "UI":
        # Text-only Gradio interface for testing the RAG pipeline without audio.
        import gradio as gr

        def gradio_chat(user_input: str) -> str:
            return generate_answer(user_input)

        iface = gr.Interface(
            fn=gradio_chat,
            inputs="text",
            outputs="text",
            title="Customer Support Chatbot",
        )
        iface.launch(server_port=7860)
    else:
        # "PHONE" (the default) and any other value serve the FastAPI/WebRTC app.
        uvicorn.run(app, host="0.0.0.0", port=7860)
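# To run (assuming this file is saved as app.py; adjust the module name as needed):
#   MODE=PHONE python app.py   # FastAPI + WebRTC voice agent on port 7860
#   MODE=UI python app.py      # text-only Gradio test UI on port 7860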