# RAG_CHATBOT / app.py
import os
import gradio as gr
import fitz # PyMuPDF for PDFs
import docx
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from gtts import gTTS  # ✅ gTTS for speech
# =============================
# 1) Auth & Config
# =============================
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("⚠️ Please set your HF_TOKEN as an environment variable.")
EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"  # ✅ you can swap in a smaller model for more speed
ASR_MODEL_ID = "openai/whisper-small"
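# Note: the Llama 3.2 checkpoints are gated on the Hugging Face Hub, so
# HF_TOKEN must belong to an account that has accepted Meta's license terms.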
# =============================
# 2) Load Models
# =============================
# Embeddings
embedding_model = SentenceTransformer(EMBED_MODEL_ID)
# LLM (no HyDE, just final answers)
qa_model = pipeline(
    "text-generation",
    model=LLM_MODEL_ID,
    token=HF_TOKEN,
    device_map="auto"
)
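# device_map="auto" asks accelerate to place the model on a GPU when one is
# available and fall back to CPU otherwise (the accelerate package must be
# installed for this to work).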
# Speech-to-Text
stt_model = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL_ID,
    token=HF_TOKEN
)
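# whisper-small balances accuracy and speed; openai/whisper-tiny or
# openai/whisper-base are lighter drop-in alternatives if transcription
# latency matters more than accuracy.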
# =============================
# 3) File Text Extraction
# =============================
def extract_text(file_path: str) -> str:
    """Extract plain text from a PDF, DOCX, or plain-text file."""
    if not file_path:
        return ""
    _, ext = os.path.splitext(file_path.lower())
    text = ""
    if ext == ".pdf":
        # Context manager ensures the PDF is closed after reading.
        with fitz.open(file_path) as doc:
            for page in doc:
                text += page.get_text("text")
    elif ext == ".docx":
        doc = docx.Document(file_path)
        for para in doc.paragraphs:
            text += para.text + "\n"
    else:
        # Fall back to treating any other extension as plain text.
        with open(file_path, "rb") as f:
            text = f.read().decode("utf-8", errors="ignore")
    return text
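# Illustrative calls (hypothetical file names):
#   extract_text("report.pdf")   # concatenated text of every PDF page
#   extract_text("notes.docx")   # paragraph texts joined with newlines
#   extract_text("readme.md")    # raw bytes decoded as UTF-8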
# =============================
# 4) Build FAISS Index
# =============================
def build_faiss(text: str, chunk_size=500, overlap=50):
    """Split text into overlapping chunks and index their embeddings."""
    if not text.strip():
        return None, None
    chunks = []
    # Slide a chunk_size window forward by (chunk_size - overlap) characters.
    step = max(1, chunk_size - overlap)
    for i in range(0, len(text), step):
        chunk = text[i:i + chunk_size]
        if chunk.strip():
            chunks.append(chunk)
    if not chunks:
        return None, None
    embeddings = embedding_model.encode(chunks, convert_to_numpy=True, normalize_embeddings=True)
    dim = embeddings.shape[1]
    # Inner-product index; with L2-normalized embeddings this equals cosine similarity.
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)
    return index, chunks
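# Worked example: with the defaults (chunk_size=500, overlap=50) the window
# advances by 450 characters, so a 1,000-character document yields chunks
# starting at offsets 0, 450, and 900 (the last one shorter than 500):
#   idx, chs = build_faiss("x" * 1000)  # len(chs) == 3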
# =============================
# 5) Globals
# =============================
doc_index = None
doc_chunks = None
# =============================
# 6) Handlers
# =============================
def upload_file(file_path: str):
    global doc_index, doc_chunks
    if not file_path:
        return "⚠️ Please upload a file first."
    text = extract_text(file_path)
    idx, chunks = build_faiss(text)
    if idx is None:
        return "⚠️ Could not index: file appears empty."
    doc_index, doc_chunks = idx, chunks
    return f"✅ Document indexed! {len(chunks)} chunks ready."
def answer_query(query: str):
    global doc_index, doc_chunks
    if not query or not query.strip():
        return "⚠️ Please enter a question."
    if doc_index is None or not doc_chunks:
        return "⚠️ Please upload and index a document first."
    # Embed the query directly (no HyDE rewriting) and retrieve the top chunks.
    q_vec = embedding_model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    scores, ids = doc_index.search(q_vec, k=min(5, len(doc_chunks)))
    retrieved = [doc_chunks[i] for i in ids[0] if 0 <= i < len(doc_chunks)]
    context = "\n\n".join(retrieved)
    final_prompt = (
        "You are a helpful assistant. Answer based only on the context. "
        "If the answer is not in the context, say you don't know.\n\n"
        f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"
    )
    # return_full_text=False keeps the prompt out of the generated answer.
    out = qa_model(final_prompt, max_new_tokens=200, do_sample=False,
                   return_full_text=False)[0]["generated_text"]
    return out.strip()
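# Retrieval sketch (illustrative values only): index.search returns a score
# matrix and an index matrix, e.g. scores = [[0.71, 0.66, ...]] and
# ids = [[12, 3, ...]]; the corresponding chunks are concatenated into the
# prompt's context block.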
def synthesize_with_gtts(text: str, out_path="out.mp3"):
    """Generate speech from text and save it as an MP3 using gTTS."""
    tts = gTTS(text=text, lang="en")
    tts.save(out_path)
    return out_path
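# Note: gTTS calls Google Translate's text-to-speech endpoint, so this step
# needs outbound network access and will fail offline.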
def voice_query(audio_path: str):
    if not audio_path:
        return "⚠️ Please record your question.", "", None
    # 1) Speech -> Text
    asr = stt_model(audio_path)
    recognized = asr.get("text", "").strip()
    if not recognized:
        return "⚠️ Could not transcribe audio.", "", None
    # 2) RAG Answer
    ans = answer_query(recognized)
    # 3) Text -> Speech (gTTS saves an MP3 file)
    mp3_path = synthesize_with_gtts(ans, "answer.mp3")
    return recognized, ans, mp3_path
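# End-to-end voice flow: microphone recording -> Whisper transcription ->
# FAISS retrieval + LLM answer -> gTTS MP3 played back in the UI.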
# =============================
# 7) Gradio UI
# =============================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="cyan")) as demo:
    gr.Markdown("# 📚 Simple RAG Chatbot + 🎤 Voice")
    gr.Markdown("Upload a PDF/DOCX/TXT and ask by typing **or** speaking. Uses Whisper for ASR and gTTS for speech output.")
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(label="📂 Upload Document", type="filepath")
            upload_btn = gr.Button("⚡ Index Document", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=2):
            gr.Markdown("### ✍️ Text Chat")
            query = gr.Textbox(label="❓ Ask a Question", placeholder="e.g., What are the key points?")
            ask_btn = gr.Button("🚀 Get Answer", variant="primary")
            answer = gr.Textbox(label="💡 Answer", lines=8)
            gr.Markdown("### 🎤 Voice Chat")
            mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Speak your question")
            rec_text = gr.Textbox(label="📝 Recognized Speech", interactive=False)
            v_answer = gr.Textbox(label="💡 Answer (from voice)", lines=8)
            v_audio = gr.Audio(label="🔊 Bot Voice Reply")
    # Bind events
    upload_btn.click(fn=upload_file, inputs=file_input, outputs=status)
    ask_btn.click(fn=answer_query, inputs=query, outputs=answer)
    mic_input.change(fn=voice_query, inputs=mic_input, outputs=[rec_text, v_answer, v_audio])
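    # mic_input.change fires as soon as a recording is saved; Gradio's
    # mic_input.stop_recording event could be used here instead to the
    # same effect.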
demo.launch()