# HyDE / app.py
# (Hugging Face Space source — original Hub page metadata preserved below as comments:
#  uploaded by TrishaThanmai, commit "update app.py" 9320525 verified, 4.4 kB)
import os
import gradio as gr
import fitz
import docx
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from gtts import gTTS
from huggingface_hub import login
# =============================
# 1) Config
# =============================
# The Space must provide an HF token (secret) for gated model/pipeline access.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("Please set HF_TOKEN in Space secrets")
login(HF_TOKEN)

# Model identifiers: sentence embeddings for retrieval, a seq2seq LLM for
# answer generation, and Whisper for speech-to-text.
EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL_ID = "google/flan-t5-base"
ASR_MODEL_ID = "openai/whisper-small"
# =============================
# 2) Load Models (cached)
# =============================
# Loaded once at module import; the HF Hub cache avoids re-downloading on restart.
embedding_model = SentenceTransformer(EMBED_MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
llm = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_ID)

# Whisper ASR pipeline backing the voice-input path.
stt_model = pipeline("automatic-speech-recognition", model=ASR_MODEL_ID, token=HF_TOKEN)
# =============================
# 3) Text Extraction
# =============================
def extract_text(file_path: str) -> str:
    """Extract plain text from a PDF, DOCX, or plain-text file.

    Args:
        file_path: Path to the uploaded file; falsy values yield "".

    Returns:
        The stripped text content, or "" on any read/parse failure
        (extraction is deliberately best-effort).
    """
    if not file_path:
        return ""
    ext = os.path.splitext(file_path)[1].lower()
    try:
        if ext == ".pdf":
            # Context manager closes the PDF handle (the original leaked it),
            # and join avoids quadratic string concatenation.
            with fitz.open(file_path) as doc:
                text = "".join(page.get_text() for page in doc)
        elif ext == ".docx":
            paragraphs = docx.Document(file_path).paragraphs
            text = "".join(p.text + "\n" for p in paragraphs)
        else:
            # Fallback: treat anything else as plain text, ignoring bad bytes.
            with open(file_path, "r", errors="ignore") as f:
                text = f.read()
    except Exception:
        # Best-effort by design: unreadable/corrupt files report as "no text".
        return ""
    return text.strip()
# =============================
# 4) Build FAISS Index
# =============================
def build_faiss(text, chunk_size=500, overlap=50):
    """Split text into overlapping chunks and index their embeddings.

    Args:
        text: Document text to index.
        chunk_size: Characters per chunk.
        overlap: Characters shared by consecutive chunks; must be smaller
            than chunk_size.

    Returns:
        (faiss.IndexFlatIP, list[str]) on success, or (None, None) when
        there is nothing to index.

    Raises:
        ValueError: If overlap >= chunk_size (the original crashed with a
            cryptic range() error in that case).
    """
    if not text:
        return None, None
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than chunk_size")
    chunks = []
    for i in range(0, len(text), step):
        chunk = text[i:i + chunk_size].strip()
        if chunk:
            chunks.append(chunk)
    if not chunks:
        return None, None
    # Normalized embeddings + inner-product index == cosine similarity search.
    embeds = embedding_model.encode(
        chunks,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    index = faiss.IndexFlatIP(embeds.shape[1])
    index.add(embeds)
    return index, chunks
# =============================
# 5) Globals
# =============================
# Module-level state for the single indexed document: the FAISS index and
# the parallel list of text chunks it was built from (both set by upload_file,
# read by answer_query). Only one document is held at a time.
doc_index = None
doc_chunks = None
# =============================
# 6) Handlers
# =============================
def upload_file(file_path):
    """Extract, chunk, and index the uploaded document; return a status string."""
    global doc_index, doc_chunks
    extracted = extract_text(file_path)
    if not extracted:
        return "❌ No readable text found."
    index, chunks = build_faiss(extracted)
    if index is None:
        return "❌ Indexing failed."
    doc_index = index
    doc_chunks = chunks
    return f"✅ Indexed {len(chunks)} chunks."
def answer_query(query):
    """Answer a question over the indexed document via retrieval + the seq2seq LLM.

    Args:
        query: User question text.

    Returns:
        The generated answer string, or a warning when no query was given
        or no document has been indexed yet.
    """
    if not query.strip():
        return "⚠️ Enter a question."
    if doc_index is None:
        return "⚠️ Upload a document first."
    q_vec = embedding_model.encode(
        [query],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    # Never request more neighbors than exist: FAISS pads missing results
    # with id -1, which would otherwise silently index doc_chunks[-1].
    k = min(5, doc_index.ntotal)
    _, I = doc_index.search(q_vec, k=k)
    context = "\n".join(doc_chunks[i] for i in I[0] if i >= 0)
    prompt = f"""
Answer using only the context below.
If not found, say "Not in document".
Context:
{context}
Question:
{query}
"""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = llm.generate(**inputs, max_new_tokens=200)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def voice_query(audio_path):
    """Transcribe recorded audio, answer it, and synthesize a spoken reply.

    Returns:
        (transcript, answer_text, reply_mp3_path); empty strings and None
        when no audio was provided.
    """
    if not audio_path:
        return "", "", None
    transcript = stt_model(audio_path)["text"]
    reply = answer_query(transcript)
    # NOTE(review): fixed output filename is shared across concurrent users —
    # confirm this is acceptable for the Space's usage pattern.
    gTTS(reply).save("reply.mp3")
    return transcript, reply, "reply.mp3"
# =============================
# 7) UI
# =============================
with gr.Blocks() as demo:
    gr.Markdown("# 📚 RAG Chatbot with Voice")

    # Document upload + indexing.
    file = gr.File(type="filepath")
    status = gr.Textbox()
    gr.Button("Index").click(upload_file, file, status)

    # Typed question-answering.
    query = gr.Textbox(label="Question")
    answer = gr.Textbox()
    gr.Button("Ask").click(answer_query, query, answer)

    # Voice path: transcript, text answer, and spoken reply playback.
    audio = gr.Audio(type="filepath")
    rec = gr.Textbox()
    v_ans = gr.Textbox()
    v_audio = gr.Audio()
    audio.change(voice_query, audio, [rec, v_ans, v_audio])

demo.launch()