# sinhala-chatbot/app/rag.py
import os
import re
import unicodedata
from pathlib import Path
from typing import List
from dotenv import load_dotenv
import google.generativeai as genai
from huggingface_hub import InferenceClient
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)

vectordb = None
retriever = None
embeddings = None
rag_initialized = False
uploaded_documents = []
last_index_mtime = None
RAG_DATA_DIR = Path(__file__).resolve().parent.parent / "rag_data"
FAISS_INDEX_PATH = RAG_DATA_DIR / "faiss_index"
INSUFFICIENT_CONTEXT_MARKER = "i don't have enough information in the documents"


def initialize_embeddings():
    """Initialize the multilingual embedding model."""
    global embeddings
    if embeddings is not None:
        return embeddings
    print("Loading multilingual embedding model...")
    from langchain_huggingface import HuggingFaceEmbeddings
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
        encode_kwargs={"normalize_embeddings": True},
    )
    print("Embedding model loaded.")
    return embeddings


def clean_text(text: str) -> str:
    """Clean and normalize text for embedding."""
    if not isinstance(text, str) or not text.strip():
        return ""
    normalized_text = unicodedata.normalize("NFKC", text)
    cleaned_chars = [
        char for char in normalized_text
        if unicodedata.category(char) not in ["So", "Cn", "Cc", "Cf", "Cs"]
    ]
    cleaned_text = "".join(cleaned_chars)
    cleaned_text = re.sub(r"\s+", " ", cleaned_text).strip()
    return cleaned_text
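
# Worked illustration (added for clarity, not in the app's call path): "\n" is
# a Cc control character, so it is dropped by the category filter *before* the
# whitespace-collapsing regex runs, while ordinary spaces are merged:
#
#     clean_text("  hello   world\u200d!  ")  # -> "hello world!" (ZWJ is Cf, removed)
#     clean_text("hello\nworld")              # -> "helloworld" (newline is Cc, removed)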


def load_and_process_pdf(pdf_path: str) -> List:
    """Load a PDF and split it into LangChain Document chunks."""
    from langchain_community.document_loaders import PyPDFLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter
    print(f"Loading PDF: {pdf_path}")
    loader = PyPDFLoader(pdf_path)
    docs = loader.load()
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=300,
        chunk_overlap=80,
    )
    chunks = splitter.split_documents(docs)
    print(f"Loaded {len(docs)} pages, created {len(chunks)} chunks.")
    return chunks
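
# Illustrative usage (the file name "sample.pdf" is hypothetical): each chunk
# is a LangChain Document of at most ~300 characters with an 80-character
# overlap, carrying page metadata from PyPDFLoader:
#
#     chunks = load_and_process_pdf("sample.pdf")
#     print(len(chunks), chunks[0].metadata)  # e.g. {'source': 'sample.pdf', 'page': 0}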


def create_vector_store(chunks: List) -> bool:
    """Create or update the FAISS vector store with document chunks."""
    global vectordb, retriever, rag_initialized
    from langchain_community.vectorstores import FAISS
    initialize_embeddings()
    texts = [doc.page_content for doc in chunks]
    metadatas = [doc.metadata for doc in chunks]
    processed_texts = []
    processed_metadatas = []
    for i, text in enumerate(texts):
        cleaned_text = clean_text(text)
        if cleaned_text:
            processed_texts.append(cleaned_text)
            processed_metadatas.append(metadatas[i])
    if not processed_texts:
        print("No valid texts after cleaning.")
        return False
    print(f"Processing {len(processed_texts)} text chunks for embedding...")
    if vectordb is None:
        vectordb = FAISS.from_texts(processed_texts, embeddings, metadatas=processed_metadatas)
    else:
        new_vectordb = FAISS.from_texts(processed_texts, embeddings, metadatas=processed_metadatas)
        vectordb.merge_from(new_vectordb)
    retriever = vectordb.as_retriever(search_kwargs={"k": 4})
    rag_initialized = True
    save_vector_store()
    _sync_uploaded_documents()
    print("Vector store created/updated successfully.")
    return True
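
# Minimal end-to-end indexing sketch using the two helpers above (the path
# "rag_data/handbook.pdf" is hypothetical):
#
#     chunks = load_and_process_pdf("rag_data/handbook.pdf")
#     if create_vector_store(chunks):
#         print("indexed", len(chunks), "chunks")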


def save_vector_store():
    """Save the FAISS index to disk."""
    global vectordb, last_index_mtime
    if vectordb is None:
        return
    RAG_DATA_DIR.mkdir(parents=True, exist_ok=True)
    vectordb.save_local(str(FAISS_INDEX_PATH))
    last_index_mtime = _get_index_mtime()
    print(f"Vector store saved to {FAISS_INDEX_PATH}.")


def load_vector_store() -> bool:
    """Load the FAISS index from disk if it exists."""
    global vectordb, retriever, rag_initialized, last_index_mtime
    if not FAISS_INDEX_PATH.exists():
        return False
    try:
        from langchain_community.vectorstores import FAISS
        initialize_embeddings()
        vectordb = FAISS.load_local(
            str(FAISS_INDEX_PATH),
            embeddings,
            allow_dangerous_deserialization=True,
        )
        retriever = vectordb.as_retriever(search_kwargs={"k": 4})
        rag_initialized = True
        last_index_mtime = _get_index_mtime()
        _sync_uploaded_documents()
        print("Loaded existing vector store from disk.")
        return True
    except Exception as e:
        print(f"Failed to load vector store: {e}")
        return False
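
# Persistence round-trip sketch showing how these two functions interact (in
# the app itself the reload happens lazily inside rag_answer / get_rag_status):
#
#     save_vector_store()          # writes the index files under FAISS_INDEX_PATH
#     rag_initialized = False      # pretend this is a fresh process
#     assert load_vector_store()   # restores vectordb, retriever and last_index_mtime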


def rag_answer(question: str) -> dict:
    """Answer a question using RAG: check the vector store first, then fall back to Gemini/HF."""
    global retriever, vectordb, last_index_mtime
    result = {
        "answer": "",
        "source": "none",
        "context_found": False,
        "relevance_score": 0.0,
    }
    # Reload the index if it was rebuilt on disk since we last read it.
    if FAISS_INDEX_PATH.exists():
        current_mtime = _get_index_mtime()
        if (not rag_initialized or retriever is None) or (
            current_mtime and last_index_mtime and current_mtime > last_index_mtime
        ):
            load_vector_store()
    if not rag_initialized or retriever is None:
        result["source"] = "gemini"
        result["answer"] = _ask_gemini_directly(question)
        return result
    docs_with_scores = vectordb.similarity_search_with_score(question, k=4)
    if not docs_with_scores:
        print(f"No documents found for question: {question}")
        result["source"] = "gemini"
        result["answer"] = _ask_gemini_directly(question)
        return result
    best_score = docs_with_scores[0][1] if docs_with_scores else float("inf")
    result["relevance_score"] = float(best_score)
    print(f"\nQuestion: {question}")
    print(f"Retrieved {len(docs_with_scores)} documents:")
    for i, (doc, score) in enumerate(docs_with_scores):
        preview = doc.page_content[:100].replace("\n", " ")
        print(f" [{i + 1}] Score: {score:.3f} - {preview}...")
    print(f"Using RAG with relevance score: {best_score}")
    docs = [doc for doc, score in docs_with_scores]
    context = "\n\n".join([d.page_content for d in docs])
    result["context_found"] = True
    prompt = (
        "You are a helpful assistant. Answer the question based ONLY on the following "
        "context from the PDF document. If the context doesn't contain enough information "
        "to answer the question, say \"I don't have enough information in the documents to "
        "answer this question.\"\n\n"
        "Context from PDF:\n"
        f"{context}\n\n"
        f"Question: {question}\n\n"
        "Answer (in English):"
    )
    try:
        gemini_key = os.getenv("GEMINI_API_KEY")
        if gemini_key:
            try:
                model = genai.GenerativeModel("models/gemini-2.5-flash")
                response = model.generate_content(prompt)
                rag_answer_text = (response.text or "").strip()
                if _is_insufficient_context_answer(rag_answer_text):
                    print("RAG context not sufficient. Falling back to direct AI answer.")
                    result["answer"] = _ask_gemini_directly(question)
                    result["source"] = "gemini"
                    return result
                result["answer"] = rag_answer_text
                result["source"] = "rag"
                return result
            except Exception as gemini_error:
                error_msg = str(gemini_error)
                print(f"Gemini error in RAG: {error_msg[:200]}...")
                if "429" in error_msg or "quota" in error_msg.lower():
                    print("Gemini quota exceeded. Using Hugging Face for RAG.")
        # Reached when there is no Gemini key or the Gemini call failed.
        print("Using Hugging Face for RAG answer...")
        rag_answer_text = _ask_huggingface_free(prompt).strip()
        if _is_insufficient_context_answer(rag_answer_text):
            print("RAG context not sufficient. Falling back to direct AI answer.")
            result["answer"] = _ask_gemini_directly(question)
            result["source"] = "gemini"
            return result
        result["answer"] = rag_answer_text
        result["source"] = "rag"
    except Exception as e:
        print(f"All RAG generation failed: {e}")
        result["answer"] = "Sorry, unable to generate answer. Please try again later."
        result["source"] = "error"
    return result
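
# Shape of the returned dict (values are illustrative; the score is the FAISS
# distance of the best hit under the default L2 index, so lower means more
# relevant):
#
#     rag_answer("What topics do the documents cover?")
#     # -> {"answer": "...", "source": "rag" | "gemini" | "error",
#     #     "context_found": True, "relevance_score": 0.41}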


def _ask_huggingface_free(prompt: str) -> str:
    """Use free Hugging Face Inference API with token if available."""
    hf_token = os.getenv("HF_API_TOKEN")
    try:
        client = InferenceClient(token=hf_token)
    except Exception as e:
        raise Exception(f"Failed to create Hugging Face client: {e}")
    messages = [{"role": "user", "content": prompt}]
    try:
        print("Calling Hugging Face API (Qwen2.5-72B-Instruct)...")
        response = client.chat_completion(
            messages=messages,
            model="Qwen/Qwen2.5-72B-Instruct",
            max_tokens=500,
            temperature=0.7,
        )
        return response.choices[0].message.content
    except Exception as e:
        error_str = str(e)
        print(f"Hugging Face primary model error: {e}")
        try:
            print("Trying backup model (Microsoft Phi-3)...")
            response = client.chat_completion(
                messages=messages,
                model="microsoft/Phi-3-mini-4k-instruct",
                max_tokens=500,
                temperature=0.7,
            )
            return response.choices[0].message.content
        except Exception as e2:
            print(f"Backup model also failed: {e2}")
            raise Exception(f"All HF models failed: {error_str}")
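
# Illustrative call (requires network access; HF_API_TOKEN is optional, but
# anonymous requests may be heavily rate-limited):
#
#     reply = _ask_huggingface_free("Say hello in Sinhala.")
#     print(reply)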


def _ask_gemini_directly(question: str) -> str:
    """Fallback: ask Gemini directly without RAG context, then Hugging Face if Gemini fails."""
    prompt = (
        "Answer the following question helpfully and accurately:\n\n"
        f"Question: {question}\n\n"
        "Answer:"
    )
    gemini_key = os.getenv("GEMINI_API_KEY")
    if gemini_key:
        try:
            model = genai.GenerativeModel("models/gemini-2.5-flash")
            response = model.generate_content(prompt)
            return response.text
        except Exception as gemini_error:
            error_msg = str(gemini_error)
            print(f"Gemini API error: {error_msg[:200]}...")
            if "429" in error_msg or "quota" in error_msg.lower():
                print("Gemini quota exceeded. Switching to Hugging Face.")
            else:
                print("Gemini error. Switching to Hugging Face.")
    else:
        print("No Gemini API key, using Hugging Face.")
    try:
        print("Using Hugging Face for direct answer...")
        return _ask_huggingface_free(prompt)
    except Exception as hf_error:
        print(f"Hugging Face error: {hf_error}")
        return (
            "Sorry, both AI services are unavailable. "
            f"Gemini is unavailable, and Hugging Face returned an error: {str(hf_error)}"
        )


def get_rag_status() -> dict:
    """Get the current status of the RAG system."""
    if not rag_initialized and FAISS_INDEX_PATH.exists():
        load_vector_store()
    _sync_uploaded_documents()
    return {
        "initialized": rag_initialized,
        "documents_count": len(uploaded_documents),
        "documents": uploaded_documents,
        "has_embeddings": embeddings is not None,
        "has_vector_store": vectordb is not None,
    }
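
# Example status payload (field values are illustrative):
#
#     {
#         "initialized": True,
#         "documents_count": 2,
#         "documents": ["faq.pdf", "handbook.pdf"],
#         "has_embeddings": True,
#         "has_vector_store": True,
#     }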


def clear_rag_data():
    """Clear all RAG data."""
    global vectordb, retriever, rag_initialized, uploaded_documents, last_index_mtime
    vectordb = None
    retriever = None
    rag_initialized = False
    uploaded_documents = []
    last_index_mtime = None
    if FAISS_INDEX_PATH.exists():
        import shutil
        shutil.rmtree(FAISS_INDEX_PATH)
    print("RAG data cleared.")
    return True


def _get_index_mtime():
    """Return the modification time of the on-disk FAISS index, if present."""
    index_file = FAISS_INDEX_PATH / "index.faiss"
    if index_file.exists():
        return index_file.stat().st_mtime
    return None


def _is_insufficient_context_answer(answer_text: str) -> bool:
    """Detect the model's 'not enough information' reply (or an empty answer)."""
    if not answer_text:
        return True
    normalized = answer_text.strip().lower()
    return INSUFFICIENT_CONTEXT_MARKER in normalized


def _sync_uploaded_documents():
    """Refresh the list of uploaded PDFs from the rag_data directory."""
    global uploaded_documents
    if not RAG_DATA_DIR.exists():
        uploaded_documents = []
        return
    uploaded_documents = sorted(
        [pdf.name for pdf in RAG_DATA_DIR.glob("*.pdf") if pdf.is_file()]
    )


def rebuild_vector_store_from_pdfs() -> bool:
    """Rebuild vector store from all PDFs in rag_data directory."""
    global vectordb, retriever, rag_initialized
    _sync_uploaded_documents()
    if not uploaded_documents:
        print("No PDFs found in rag_data to rebuild vector store.")
        return False
    initialize_embeddings()
    vectordb = None
    retriever = None
    rag_initialized = False
    all_chunks = []
    for filename in uploaded_documents:
        pdf_path = RAG_DATA_DIR / filename
        try:
            chunks = load_and_process_pdf(str(pdf_path))
            all_chunks.extend(chunks)
        except Exception as e:
            print(f"Skipping PDF '{filename}' due to processing error: {e}")
    if not all_chunks:
        print("No chunks generated from PDFs. Rebuild aborted.")
        return False
    success = create_vector_store(all_chunks)
    if success:
        print(f"Rebuilt vector store from {len(uploaded_documents)} PDF(s).")
    return success
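

# Small smoke test for local runs (added sketch; the Space imports this module
# rather than executing it directly):
if __name__ == "__main__":
    if not load_vector_store():
        rebuild_vector_store_from_pdfs()
    print(get_rag_status())
    print(rag_answer("What do the uploaded documents cover?")["answer"])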